From 3ecf1ad4f5f2d61d31184c84be7ccfffc385611f Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 12 Jan 2018 21:33:00 +0800 Subject: [PATCH 1/8] Decouple ext scorer init & inference & decoding for the convenience of tuning --- examples/librispeech/run_tune.sh | 2 +- infer.py | 32 ++++--- model_utils/model.py | 158 +++++++++++++++++-------------- test.py | 31 ++++-- tools/tune.py | 104 +++++--------------- 5 files changed, 153 insertions(+), 174 deletions(-) diff --git a/examples/librispeech/run_tune.sh b/examples/librispeech/run_tune.sh index c3695d1cb2c..9fc9cbb9d95 100644 --- a/examples/librispeech/run_tune.sh +++ b/examples/librispeech/run_tune.sh @@ -7,7 +7,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 \ python -u tools/tune.py \ --num_batches=-1 \ --batch_size=128 \ ---trainer_count=8 \ +--trainer_count=4 \ --beam_size=500 \ --num_proc_bsearch=12 \ --num_conv_layers=2 \ diff --git a/infer.py b/infer.py index b801c507b72..1539fbaaff4 100644 --- a/infer.py +++ b/infer.py @@ -90,18 +90,26 @@ def infer(): # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] - result_transcripts = ds2_model.infer_batch( - infer_data=infer_data, - decoding_method=args.decoding_method, - beam_alpha=args.alpha, - beam_beta=args.beta, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - cutoff_top_n=args.cutoff_top_n, - vocab_list=vocab_list, - language_model_path=args.lang_model_path, - num_processes=args.num_proc_bsearch, - feeding_dict=data_generator.feeding) + probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, + feeding_dict=data_generator.feeding) + if args.decoding_method == "ctc_greedy": + ds2_model.logger.info("start inference ...") + result_transcripts = ds2_model.infer_batch_greedy( + probs_split=probs_split, + vocab_list=vocab_list) + else: + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) + ds2_model.logger.info("start inference ...") + result_transcripts = ds2_model.infer_batch_beam_search( + probs_split=probs_split, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, + num_processes=args.num_proc_bsearch) error_rate_func = cer if args.error_rate_type == 'cer' else wer target_transcripts = [data[1] for data in infer_data] diff --git a/model_utils/model.py b/model_utils/model.py index 85d50053ee7..f6d3ef05979 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -173,43 +173,19 @@ def infer_loss_batch(self, infer_data): # run inference return self._loss_inferer.infer(input=infer_data) - def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta, - beam_size, cutoff_prob, cutoff_top_n, vocab_list, - language_model_path, num_processes, feeding_dict): - """Model inference. Infer the transcription for a batch of speech - utterances. + def infer_probs_batch(self, infer_data, feeding_dict): + """Infer the prob matrices for a batch of speech utterances. :param infer_data: List of utterances to infer, with each utterance consisting of a tuple of audio features and transcription text (empty string). :type infer_data: list - :param decoding_method: Decoding method name, 'ctc_greedy' or - 'ctc_beam_search'. - :param decoding_method: string - :param beam_alpha: Parameter associated with language model. - :type beam_alpha: float - :param beam_beta: Parameter associated with word count. - :type beam_beta: float - :param beam_size: Width for Beam search. 
-        :type beam_size: int
-        :param cutoff_prob: Cutoff probability in pruning,
-                            default 1.0, no pruning.
-        :type cutoff_prob: float
-        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
-                             characters with highest probs in vocabulary will be
-                             used in beam search, default 40.
-        :type cutoff_top_n: int
-        :param vocab_list: List of tokens in the vocabulary, for decoding.
-        :type vocab_list: list
-        :param language_model_path: Filepath for language model.
-        :type language_model_path: basestring|None
-        :param num_processes: Number of processes (CPU) for decoder.
-        :type num_processes: int
         :param feeding_dict: Feeding is a map of field name and tuple index
                              of the data that reader returns.
         :type feeding_dict: dict|list
-        :return: List of transcription texts.
-        :rtype: List of basestring
+        :return: List of 2-D probability matrices, each consisting of prob
+                 vectors for one speech utterance.
+        :rtype: List of matrix
         """
         # define inferer
         if self._inferer == None:
@@ -227,49 +203,91 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
             infer_results[start_pos[i]:start_pos[i + 1]]
             for i in xrange(0, len(adapted_infer_data))
         ]
-        # run decoder
+        return probs_split
+
+    def infer_batch_greedy(self, probs_split, vocab_list):
+        """
+        :param probs_split: List of 2-D probability matrices, each consisting
+                            of prob vectors for one speech utterance.
+        :type probs_split: List of matrix
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :return: List of transcription texts.
+        :rtype: List of basestring
+        """
         results = []
-        if decoding_method == "ctc_greedy":
-            # best path decode
-            for i, probs in enumerate(probs_split):
-                output_transcription = ctc_greedy_decoder(
-                    probs_seq=probs, vocabulary=vocab_list)
-                results.append(output_transcription)
-        elif decoding_method == "ctc_beam_search":
-            # initialize external scorer
-            if self._ext_scorer == None:
-                self._loaded_lm_path = language_model_path
-                self.logger.info("begin to initialize the external scorer "
-                                 "for decoding")
-                self._ext_scorer = Scorer(beam_alpha, beam_beta,
-                                          language_model_path, vocab_list)
-
-                lm_char_based = self._ext_scorer.is_character_based()
-                lm_max_order = self._ext_scorer.get_max_order()
-                lm_dict_size = self._ext_scorer.get_dict_size()
-                self.logger.info("language model: "
-                                 "is_character_based = %d," % lm_char_based +
-                                 " max_order = %d," % lm_max_order +
-                                 " dict_size = %d" % lm_dict_size)
-                self.logger.info("end initializing scorer. Start decoding ...")
-            else:
-                self._ext_scorer.reset_params(beam_alpha, beam_beta)
-                assert self._loaded_lm_path == language_model_path
-            # beam search decode
-            num_processes = min(num_processes, len(probs_split))
-            beam_search_results = ctc_beam_search_decoder_batch(
-                probs_split=probs_split,
-                vocabulary=vocab_list,
-                beam_size=beam_size,
-                num_processes=num_processes,
-                ext_scoring_func=self._ext_scorer,
-                cutoff_prob=cutoff_prob,
-                cutoff_top_n=cutoff_top_n)
-
-            results = [result[0][1] for result in beam_search_results]
+        for i, probs in enumerate(probs_split):
+            output_transcription = ctc_greedy_decoder(
+                probs_seq=probs, vocabulary=vocab_list)
+            results.append(output_transcription)
+        return results
+
+    def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
+                        vocab_list):
+        """Initialize the external scorer.
+
+        """
+        if language_model_path != '':
+            self.logger.info("begin to initialize the external scorer "
+                             "for decoding")
+            self._ext_scorer = Scorer(beam_alpha, beam_beta,
+                                      language_model_path, vocab_list)
+            lm_char_based = self._ext_scorer.is_character_based()
+            lm_max_order = self._ext_scorer.get_max_order()
+            lm_dict_size = self._ext_scorer.get_dict_size()
+            self.logger.info("language model: "
+                             "is_character_based = %d," % lm_char_based +
+                             " max_order = %d," % lm_max_order +
+                             " dict_size = %d" % lm_dict_size)
+            self.logger.info("end initializing scorer")
         else:
-            raise ValueError("Decoding method [%s] is not supported." %
-                             decoding_method)
+            self._ext_scorer = None
+            self.logger.info("no language model provided, "
+                             "decoding by pure beam search without scorer.")
+
+    def infer_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
+                                beam_size, cutoff_prob, cutoff_top_n,
+                                vocab_list, num_processes):
+        """Model inference. Infer the transcription for a batch of speech
+        utterances.
+
+        :param probs_split: List of 2-D probability matrices, each consisting
+                            of prob vectors for one speech utterance.
+        :type probs_split: List of matrix
+        :param beam_alpha: Parameter associated with language model.
+        :type beam_alpha: float
+        :param beam_beta: Parameter associated with word count.
+        :type beam_beta: float
+        :param beam_size: Width for Beam search.
+        :type beam_size: int
+        :param cutoff_prob: Cutoff probability in pruning,
+                            default 1.0, no pruning.
+        :type cutoff_prob: float
+        :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
+                             characters with highest probs in vocabulary will be
+                             used in beam search, default 40.
+        :type cutoff_top_n: int
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :param num_processes: Number of processes (CPU) for decoder.
+        :type num_processes: int
+        :return: List of transcription texts.
+ :rtype: List of basestring + """ + if self._ext_scorer != None: + self._ext_scorer.reset_params(beam_alpha, beam_beta) + # beam search decode + num_processes = min(num_processes, len(probs_split)) + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=beam_size, + num_processes=num_processes, + ext_scoring_func=self._ext_scorer, + cutoff_prob=cutoff_prob, + cutoff_top_n=cutoff_top_n) + + results = [result[0][1] for result in beam_search_results] return results def _adapt_feeding_dict(self, feeding_dict): diff --git a/test.py b/test.py index 5cf7664870a..24ce54a2be8 100644 --- a/test.py +++ b/test.py @@ -90,22 +90,33 @@ def evaluate(): # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + if args.decoding_method == "ctc_beam_search": + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) errors_func = char_errors if args.error_rate_type == 'cer' else word_errors errors_sum, len_refs, num_ins = 0.0, 0, 0 + ds2_model.logger.info("start evaluation ...") for infer_data in batch_reader(): - result_transcripts = ds2_model.infer_batch( + probs_split = ds2_model.infer_probs_batch( infer_data=infer_data, - decoding_method=args.decoding_method, - beam_alpha=args.alpha, - beam_beta=args.beta, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - cutoff_top_n=args.cutoff_top_n, - vocab_list=vocab_list, - language_model_path=args.lang_model_path, - num_processes=args.num_proc_bsearch, feeding_dict=data_generator.feeding) + + if args.decoding_method == "ctc_greedy": + result_transcripts = ds2_model.infer_batch_greedy( + probs_split=probs_split, + vocab_list=vocab_list) + else: + result_transcripts = ds2_model.infer_batch_beam_search( + probs_split=probs_split, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, + num_processes=args.num_proc_bsearch) target_transcripts = [data[1] for data in infer_data] + for target, result in zip(target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors diff --git a/tools/tune.py b/tools/tune.py index b132331953e..83978be8d8d 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -13,9 +13,7 @@ import paddle.v2 as paddle import _init_paths from data_utils.data import DataGenerator -from decoders.swig_wrapper import Scorer -from decoders.swig_wrapper import ctc_beam_search_decoder_batch -from model_utils.model import deep_speech_v2_network +from model_utils.model import DeepSpeech2Model from utils.error_rate import char_errors, word_errors from utils.utility import add_arguments, print_arguments @@ -88,40 +86,7 @@ def tune(): augmentation_config='{}', specgram_type=args.specgram_type, num_threads=args.num_proc_data, - keep_transcription_text=True, - num_conv_layers=args.num_conv_layers) - - audio_data = paddle.layer.data( - name="audio_spectrogram", - type=paddle.data_type.dense_array(161 * 161)) - text_data = paddle.layer.data( - name="transcript_text", - type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) - seq_offset_data = paddle.layer.data( - name='sequence_offset', - type=paddle.data_type.integer_value_sequence(1)) - seq_len_data = paddle.layer.data( - name='sequence_length', - type=paddle.data_type.integer_value_sequence(1)) - index_range_datas = [] - for i in xrange(args.num_rnn_layers): - 
index_range_datas.append( - paddle.layer.data( - name='conv%d_index_range' % i, - type=paddle.data_type.dense_vector(6))) - - output_probs, _ = deep_speech_v2_network( - audio_data=audio_data, - text_data=text_data, - seq_offset_data=seq_offset_data, - seq_len_data=seq_len_data, - index_range_datas=index_range_datas, - dict_size=data_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_size=args.rnn_layer_size, - use_gru=args.use_gru, - share_rnn_weights=args.share_rnn_weights) + keep_transcription_text=True) batch_reader = data_generator.batch_reader_creator( manifest_path=args.tune_manifest, @@ -129,35 +94,17 @@ def tune(): sortagrad=False, shuffle_method=None) - # load parameters - if not os.path.isfile(args.model_path): - raise IOError("Invaid model path: %s" % args.model_path) - parameters = paddle.parameters.Parameters.from_tar( - gzip.open(args.model_path)) + ds2_model = DeepSpeech2Model( + vocab_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_layer_size=args.rnn_layer_size, + use_gru=args.use_gru, + pretrained_model_path=args.model_path, + share_rnn_weights=args.share_rnn_weights) - inferer = paddle.inference.Inference( - output_layer=output_probs, parameters=parameters) # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] - - # init logger - logger = logging.getLogger("") - logger.setLevel(level=logging.INFO) - # init external scorer - logger.info("begin to initialize the external scorer for tuning") - if not os.path.isfile(args.lang_model_path): - raise IOError("Invaid language model path: %s" % args.lang_model_path) - ext_scorer = Scorer( - alpha=args.alpha_from, - beta=args.beta_from, - model_path=args.lang_model_path, - vocabulary=vocab_list) - logger.info("language model: " - "is_character_based = %d," % ext_scorer.is_character_based() + - " max_order = %d," % ext_scorer.get_max_order() + - " dict_size = %d" % ext_scorer.get_dict_size()) - logger.info("end initializing scorer. 
Start tuning ...") - errors_func = char_errors if args.error_rate_type == 'cer' else word_errors # create grid for search cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) @@ -168,37 +115,32 @@ def tune(): err_sum = [0.0 for i in xrange(len(params_grid))] err_ave = [0.0 for i in xrange(len(params_grid))] num_ins, len_refs, cur_batch = 0, 0, 0 + # initialize external scorer + ds2_model.init_ext_scorer(args.alpha_from, args.beta_from, + args.lang_model_path, vocab_list) ## incremental tuning parameters over multiple batches + ds2_model.logger.info("start tuning ...") for infer_data in batch_reader(): if (args.num_batches >= 0) and (cur_batch >= args.num_batches): break - infer_results = inferer.infer(input=infer_data, - feeding=data_generator.feeding) - start_pos = [0] * (len(infer_data) + 1) - for i in xrange(len(infer_data)): - start_pos[i + 1] = start_pos[i] + infer_data[i][3][0] - probs_split = [ - infer_results[start_pos[i]:start_pos[i + 1]] - for i in xrange(0, len(infer_data)) - ] - + probs_split = ds2_model.infer_probs_batch( + infer_data=infer_data, + feeding_dict=data_generator.feeding) target_transcripts = [ data[1] for data in infer_data ] num_ins += len(target_transcripts) # grid search for index, (alpha, beta) in enumerate(params_grid): - # reset alpha & beta - ext_scorer.reset_params(alpha, beta) - beam_search_results = ctc_beam_search_decoder_batch( + result_transcripts = ds2_model.infer_batch_beam_search( probs_split=probs_split, - vocabulary=vocab_list, + beam_alpha=alpha, + beam_beta=beta, beam_size=args.beam_size, - num_processes=args.num_proc_bsearch, cutoff_prob=args.cutoff_prob, cutoff_top_n=args.cutoff_top_n, - ext_scoring_func=ext_scorer, ) + vocab_list=vocab_list, + num_processes=args.num_proc_bsearch) - result_transcripts = [res[0][1] for res in beam_search_results] for target, result in zip(target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) err_sum[index] += errors @@ -235,7 +177,7 @@ def tune(): % (cur_batch, "%.3f" % params_grid[min_index][0], "%.3f" % params_grid[min_index][1])) - logger.info("finish tuning") + ds2_model.logger.info("finish tuning") def main(): From 8ae25aebe17d4bb672092e74ae6c31cdae692775 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 12 Jan 2018 21:53:43 +0800 Subject: [PATCH 2/8] Add more comments in init_ext_scorer() --- model_utils/model.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/model_utils/model.py b/model_utils/model.py index f6d3ef05979..70ba7bb93c4 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -226,6 +226,17 @@ def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path, vocab_list): """Initialize the external scorer. + :param beam_alpha: Parameter associated with language model. + :type beam_alpha: float + :param beam_beta: Parameter associated with word count. + :type beam_beta: float + :param language_model_path: Filepath for language model. If it is + empty, the external scorer will be set to + None, and the decoding method will be pure + beam search without scorer. + :type language_model_path: basestring|None + :param vocab_list: List of tokens in the vocabulary, for decoding. 
+ :type vocab_list: list """ if language_model_path != '': self.logger.info("begin to initialize the external scorer " From 10d337097005ac56f57f429266543d892c36a64d Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 12 Jan 2018 22:50:19 +0800 Subject: [PATCH 3/8] Remove redundant lines in tune.py --- tools/tune.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tools/tune.py b/tools/tune.py index 83978be8d8d..923e6c3c32a 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -70,9 +70,6 @@ args = parser.parse_args() -logging.basicConfig( - format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') - def tune(): """Tune parameters alpha and beta incrementally.""" if not args.num_alphas >= 0: From 3a36c8a69ea50200439794f7cb87a97267044887 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 12 Jan 2018 23:22:08 +0800 Subject: [PATCH 4/8] Adapt demo_server to the decoupling in infer_batch() --- deploy/demo_server.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/deploy/demo_server.py b/deploy/demo_server.py index d64f9f01551..53be16f77a7 100644 --- a/deploy/demo_server.py +++ b/deploy/demo_server.py @@ -160,22 +160,30 @@ def start_server(): vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + if args.decoding_method == "ctc_beam_search": + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) # prepare ASR inference handler def file_to_transcript(filename): feature = data_generator.process_utterance(filename, "") - - result_transcript = ds2_model.infer_batch( + probs_split = ds2_model.infer_probs_batch( infer_data=[feature], - decoding_method=args.decoding_method, - beam_alpha=args.alpha, - beam_beta=args.beta, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - cutoff_top_n=args.cutoff_top_n, - vocab_list=vocab_list, - language_model_path=args.lang_model_path, - num_processes=1, feeding_dict=data_generator.feeding) + + if args.decoding_method == "ctc_greedy": + result_transcript = ds2_model.infer_batch_greedy( + probs_split=probs_split, + vocab_list=vocab_list) + else: + result_transcript = ds2_model.infer_batch_beam_search( + probs_split=probs_split, + beam_alpha=args.alpha, + beam_beta=args.beta, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + cutoff_top_n=args.cutoff_top_n, + vocab_list=vocab_list, + num_processes=1) return result_transcript[0] # warming up with utterrances sampled from Librispeech From 66a39088180052813c80a33919b85ff976b6e076 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 12 Jan 2018 23:33:03 +0800 Subject: [PATCH 5/8] Adjust the order of scorer init & probs infer in infer.py --- infer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/infer.py b/infer.py index 1539fbaaff4..5dd9b406d1f 100644 --- a/infer.py +++ b/infer.py @@ -90,17 +90,18 @@ def infer(): # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] + if args.decoding_method == "ctc_beam_search": + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) + + ds2_model.logger.info("start inference ...") probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, feeding_dict=data_generator.feeding) if args.decoding_method == "ctc_greedy": - ds2_model.logger.info("start inference ...") result_transcripts = ds2_model.infer_batch_greedy( probs_split=probs_split, vocab_list=vocab_list) else: - ds2_model.init_ext_scorer(args.alpha, args.beta, 
-                                  args.lang_model_path,
-                                  vocab_list)
-        ds2_model.logger.info("start inference ...")
         result_transcripts = ds2_model.infer_batch_beam_search(
             probs_split=probs_split,
             beam_alpha=args.alpha,

From 6c2cf40ce1abbd60a775ee0272bab48836ff9848 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Sat, 13 Jan 2018 11:27:40 +0800
Subject: [PATCH 6/8] Rename prefix 'infer_batch' to 'decode_batch'

---
 deploy/demo_server.py |  4 ++--
 infer.py              |  4 ++--
 model_utils/model.py  | 14 +++++++-------
 test.py               |  4 ++--
 tools/tune.py         |  2 +-
 5 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/deploy/demo_server.py b/deploy/demo_server.py
index 53be16f77a7..eca13dcea8d 100644
--- a/deploy/demo_server.py
+++ b/deploy/demo_server.py
@@ -171,11 +171,11 @@ def file_to_transcript(filename):
             feeding_dict=data_generator.feeding)
 
         if args.decoding_method == "ctc_greedy":
-            result_transcript = ds2_model.infer_batch_greedy(
+            result_transcript = ds2_model.decode_batch_greedy(
                 probs_split=probs_split,
                 vocab_list=vocab_list)
         else:
-            result_transcript = ds2_model.infer_batch_beam_search(
+            result_transcript = ds2_model.decode_batch_beam_search(
                 probs_split=probs_split,
                 beam_alpha=args.alpha,
                 beam_beta=args.beta,
diff --git a/infer.py b/infer.py
index 5dd9b406d1f..ff45a5dc864 100644
--- a/infer.py
+++ b/infer.py
@@ -98,11 +98,11 @@ def infer():
     probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
                               feeding_dict=data_generator.feeding)
     if args.decoding_method == "ctc_greedy":
-        result_transcripts = ds2_model.infer_batch_greedy(
+        result_transcripts = ds2_model.decode_batch_greedy(
             probs_split=probs_split,
             vocab_list=vocab_list)
     else:
-        result_transcripts = ds2_model.infer_batch_beam_search(
+        result_transcripts = ds2_model.decode_batch_beam_search(
             probs_split=probs_split,
             beam_alpha=args.alpha,
             beam_beta=args.beta,
diff --git a/model_utils/model.py b/model_utils/model.py
index 70ba7bb93c4..a8283fae451 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -205,8 +205,9 @@ def infer_probs_batch(self, infer_data, feeding_dict):
         ]
         return probs_split
 
-    def infer_batch_greedy(self, probs_split, vocab_list):
-        """
+    def decode_batch_greedy(self, probs_split, vocab_list):
+        """Decode by best path for a batch of probs matrix input.
+
         :param probs_split: List of 2-D probability matrices, each consisting
                             of prob vectors for one speech utterance.
         :type probs_split: List of matrix
@@ -256,11 +257,10 @@ def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
             self.logger.info("no language model provided, "
                              "decoding by pure beam search without scorer.")
 
-    def infer_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
-                                beam_size, cutoff_prob, cutoff_top_n,
-                                vocab_list, num_processes):
-        """Model inference. Infer the transcription for a batch of speech
-        utterances.
+    def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
+                                 beam_size, cutoff_prob, cutoff_top_n,
+                                 vocab_list, num_processes):
+        """Decode by beam search for a batch of probs matrix input.
 
         :param probs_split: List of 2-D probability matrices, each consisting
                             of prob vectors for one speech utterance.
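The rename makes the two-stage pipeline explicit at every call site: infer_probs_batch() runs the network once, and decode_batch_*() turns probabilities into text. The payoff is in tools/tune.py, where the forward pass is shared across the whole (alpha, beta) grid. A condensed sketch of that loop, using the method names as of this patch and assuming ds2_model, batch_reader, params_grid, vocab_list, and args are set up as in tools/tune.py (error accounting elided):

    for infer_data in batch_reader():
        # stage 1: run the acoustic model once per batch
        probs_split = ds2_model.infer_probs_batch(
            infer_data=infer_data,
            feeding_dict=data_generator.feeding)
        # stage 2: re-decode the same probs for every grid point;
        # decode_batch_beam_search() calls reset_params(alpha, beta) on the
        # already-initialized scorer, so no re-inference is needed
        for index, (alpha, beta) in enumerate(params_grid):
            result_transcripts = ds2_model.decode_batch_beam_search(
                probs_split=probs_split,
                beam_alpha=alpha,
                beam_beta=beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                num_processes=args.num_proc_bsearch)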
diff --git a/test.py b/test.py index 24ce54a2be8..a82893c03bb 100644 --- a/test.py +++ b/test.py @@ -102,11 +102,11 @@ def evaluate(): feeding_dict=data_generator.feeding) if args.decoding_method == "ctc_greedy": - result_transcripts = ds2_model.infer_batch_greedy( + result_transcripts = ds2_model.decode_batch_greedy( probs_split=probs_split, vocab_list=vocab_list) else: - result_transcripts = ds2_model.infer_batch_beam_search( + result_transcripts = ds2_model.decode_batch_beam_search( probs_split=probs_split, beam_alpha=args.alpha, beam_beta=args.beta, diff --git a/tools/tune.py b/tools/tune.py index 923e6c3c32a..d8e28c58a5e 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -128,7 +128,7 @@ def tune(): num_ins += len(target_transcripts) # grid search for index, (alpha, beta) in enumerate(params_grid): - result_transcripts = ds2_model.infer_batch_beam_search( + result_transcripts = ds2_model.decode_batch_beam_search( probs_split=probs_split, beam_alpha=alpha, beam_beta=beta, From dd2588c96b4589284d73528a3a8566875edc6cc4 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 15 Jan 2018 14:17:07 +0800 Subject: [PATCH 7/8] Merge two if statements in infer --- infer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/infer.py b/infer.py index ff45a5dc864..4a5f8cb05e9 100644 --- a/infer.py +++ b/infer.py @@ -90,18 +90,19 @@ def infer(): # decoders only accept string encoded in utf-8 vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list] - if args.decoding_method == "ctc_beam_search": - ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, - vocab_list) - - ds2_model.logger.info("start inference ...") - probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, - feeding_dict=data_generator.feeding) if args.decoding_method == "ctc_greedy": + ds2_model.logger.info("start inference ...") + probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, + feeding_dict=data_generator.feeding) result_transcripts = ds2_model.decode_batch_greedy( probs_split=probs_split, vocab_list=vocab_list) else: + ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, + vocab_list) + ds2_model.logger.info("start inference ...") + probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, + feeding_dict=data_generator.feeding) result_transcripts = ds2_model.decode_batch_beam_search( probs_split=probs_split, beam_alpha=args.alpha, From 7c6fa642cda67554c7731c5e38e955fd7e9b0afc Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Mon, 15 Jan 2018 14:34:59 +0800 Subject: [PATCH 8/8] Rename infer_probs_batch to infer_batch_probs --- deploy/demo_server.py | 2 +- infer.py | 4 ++-- model_utils/model.py | 2 +- test.py | 2 +- tools/tune.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/deploy/demo_server.py b/deploy/demo_server.py index eca13dcea8d..1cafb7a58f6 100644 --- a/deploy/demo_server.py +++ b/deploy/demo_server.py @@ -166,7 +166,7 @@ def start_server(): # prepare ASR inference handler def file_to_transcript(filename): feature = data_generator.process_utterance(filename, "") - probs_split = ds2_model.infer_probs_batch( + probs_split = ds2_model.infer_batch_probs( infer_data=[feature], feeding_dict=data_generator.feeding) diff --git a/infer.py b/infer.py index 4a5f8cb05e9..f4d75685b90 100644 --- a/infer.py +++ b/infer.py @@ -92,7 +92,7 @@ def infer(): if args.decoding_method == "ctc_greedy": ds2_model.logger.info("start inference ...") - probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, + 
probs_split = ds2_model.infer_batch_probs(infer_data=infer_data, feeding_dict=data_generator.feeding) result_transcripts = ds2_model.decode_batch_greedy( probs_split=probs_split, @@ -101,7 +101,7 @@ def infer(): ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, vocab_list) ds2_model.logger.info("start inference ...") - probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, + probs_split = ds2_model.infer_batch_probs(infer_data=infer_data, feeding_dict=data_generator.feeding) result_transcripts = ds2_model.decode_batch_beam_search( probs_split=probs_split, diff --git a/model_utils/model.py b/model_utils/model.py index a8283fae451..4b3764bf2f0 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -173,7 +173,7 @@ def infer_loss_batch(self, infer_data): # run inference return self._loss_inferer.infer(input=infer_data) - def infer_probs_batch(self, infer_data, feeding_dict): + def infer_batch_probs(self, infer_data, feeding_dict): """Infer the prob matrices for a batch of speech utterances. :param infer_data: List of utterances to infer, with each utterance diff --git a/test.py b/test.py index a82893c03bb..e5a3346a0ac 100644 --- a/test.py +++ b/test.py @@ -97,7 +97,7 @@ def evaluate(): errors_sum, len_refs, num_ins = 0.0, 0, 0 ds2_model.logger.info("start evaluation ...") for infer_data in batch_reader(): - probs_split = ds2_model.infer_probs_batch( + probs_split = ds2_model.infer_batch_probs( infer_data=infer_data, feeding_dict=data_generator.feeding) diff --git a/tools/tune.py b/tools/tune.py index d8e28c58a5e..da785189f1d 100644 --- a/tools/tune.py +++ b/tools/tune.py @@ -120,7 +120,7 @@ def tune(): for infer_data in batch_reader(): if (args.num_batches >= 0) and (cur_batch >= args.num_batches): break - probs_split = ds2_model.infer_probs_batch( + probs_split = ds2_model.infer_batch_probs( infer_data=infer_data, feeding_dict=data_generator.feeding) target_transcripts = [ data[1] for data in infer_data ]
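Taken together, the series leaves every caller with the same three-step flow: initialize the external scorer once (beam search only), run the acoustic forward pass per batch, then decode. A minimal end-to-end sketch of the final API, assuming ds2_model and data_generator are constructed as in infer.py and that args carries the usual command-line values:

    # decoders only accept strings encoded in utf-8
    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    if args.decoding_method == "ctc_beam_search":
        # one-time init; an empty lang_model_path yields scorer-less beam search
        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
                                  vocab_list)

    # step 1: acoustic model forward pass only
    probs_split = ds2_model.infer_batch_probs(
        infer_data=infer_data,
        feeding_dict=data_generator.feeding)

    # step 2: decoding; repeatable with new (alpha, beta) without redoing step 1
    if args.decoding_method == "ctc_greedy":
        result_transcripts = ds2_model.decode_batch_greedy(
            probs_split=probs_split,
            vocab_list=vocab_list)
    else:
        result_transcripts = ds2_model.decode_batch_beam_search(
            probs_split=probs_split,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            cutoff_top_n=args.cutoff_top_n,
            vocab_list=vocab_list,
            num_processes=args.num_proc_bsearch)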