final refining on old data provider: enable pruning & add evaluation & code cleanup
Yibing Liu committed Jun 18, 2017
1 parent 0fa063e commit 08203ee
Showing 4 changed files with 339 additions and 72 deletions.
84 changes: 61 additions & 23 deletions deep_speech_2/decoder.py
@@ -5,7 +5,6 @@
import os
from itertools import groupby
import numpy as np
import copy
import kenlm
import multiprocessing

@@ -73,25 +72,40 @@ def word_count(self, sentence):
return len(words)

# execute evaluation
def __call__(self, sentence):
def __call__(self, sentence, log=False):
"""
Evaluation function
:param sentence: The input sentence for evaluation.
:type sentence: basestring
:param log: Whether to return the score in log representation.
:type log: bool
:return: Evaluation score, in decimal or log representation.
:rtype: float
"""
lm = self.language_model_score(sentence)
word_cnt = self.word_count(sentence)
score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta)
if log == False:
score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta)
else:
score = self._alpha * np.log(lm) \
+ self._beta * np.log(word_cnt)
return score
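
For context, the new log option moves the same combination into the log domain: score = lm^alpha * word_cnt^beta in decimal form, or alpha * log(lm) + beta * log(word_cnt) when log is True. A minimal standalone sketch of that combination follows; the function name, the alpha/beta values, and the toy numbers are illustrative only and not part of this commit:

import numpy as np

def combined_score(lm_prob, word_cnt, alpha=0.26, beta=0.1, log=False):
    # mirrors the scorer's __call__: lm^alpha * word_cnt^beta,
    # or the equivalent sum in the log domain
    if not log:
        return np.power(lm_prob, alpha) * np.power(word_cnt, beta)
    return alpha * np.log(lm_prob) + beta * np.log(word_cnt)

print(combined_score(1e-12, 5))            # decimal-domain score
print(combined_score(1e-12, 5, log=True))  # log-domain score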


def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
blank_id=0,
cutoff_prob=1.0,
ext_scoring_func=None,
nproc=False):
'''
Beam search decoder for a CTC-trained network, using beam search with width
beam_size to find many paths to one label and returning beam_size labels in
the order of probabilities. The implementation is based on Prefix Beam
Search(https://arxiv.org/abs/1408.2873), and the unclear part is
the descending order of probabilities. The implementation is based on Prefix
Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned and still needs to be verified.
:param probs_seq: 2-D list with length num_time_steps, each element
@@ -102,22 +116,25 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: Externally defined scoring function for
partially decoded sentence, e.g. word count
and language model.
:type ext_scoring_func: function
:param blank_id: id of blank, default 0.
:type blank_id: int
:param nproc: Whether the decoder is used in multiprocessing.
:type nproc: bool
:return: Decoding log probability and result string.
:return: Decoding log probabilities and result sentences in descending order.
:rtype: list
'''
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary")
raise ValueError("probs dimension mismatched with vocabulary")
num_time_steps = len(probs_seq)

# blank_id check
@@ -137,19 +154,35 @@
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

## extend prefix in loop
for time_step in range(num_time_steps):
for time_step in xrange(num_time_steps):
# the set containing candidate prefixes
prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {}
prob = probs_seq[time_step]
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
cutoff_len = len(prob_idx)
# If pruning is enabled
if (cutoff_prob < 1.0):
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len = 0
cum_prob = 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
prob_idx = prob_idx[0:cutoff_len]

for l in prefix_set_prev:
prob = probs_seq[time_step]
if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

# extend prefix by travering vocabulary
for c in range(0, probs_dim):
# extend prefix by traversing prob_idx
for index in xrange(cutoff_len):
c, prob_c = prob_idx[index][0], prob_idx[index][1]

if c == blank_id:
probs_b_cur[l] += prob[c] * (
probs_b_cur[l] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
last_char = l[-1]
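
The pruning added above sorts each time step's probabilities in descending order and keeps only as many entries as are needed to reach the cumulative mass cutoff_prob (with cutoff_prob=1.0 nothing is dropped). A standalone sketch of that selection step, with an illustrative function name and toy values:

def prune_step_probs(step_probs, cutoff_prob=1.0):
    # sort (index, prob) pairs by probability, highest first
    prob_idx = sorted(enumerate(step_probs), key=lambda x: x[1], reverse=True)
    if cutoff_prob >= 1.0:
        return prob_idx
    pruned, cum_prob = [], 0.0
    for idx, p in prob_idx:
        pruned.append((idx, p))
        cum_prob += p
        if cum_prob >= cutoff_prob:
            break
    return pruned

# keeps only the two most probable entries, which already cover 0.9
print(prune_step_probs([0.6, 0.3, 0.05, 0.05], cutoff_prob=0.9))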
@@ -159,18 +192,18 @@
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

if new_char == last_char:
probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
elif new_char == ' ':
if (ext_scoring_func is None) or (len(l) == 1):
score = 1.0
else:
prefix = l[1:]
score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob[c] * (
probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
probs_nb_cur[l_plus] += prob[c] * (
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
@@ -203,6 +236,7 @@ def ctc_beam_search_decoder_nproc(probs_split,
beam_size,
vocabulary,
blank_id=0,
cutoff_prob=1.0,
ext_scoring_func=None,
num_processes=None):
'''
@@ -216,16 +250,19 @@ def ctc_beam_search_decoder_nproc(probs_split,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: Externally defined scoring function for
partially decoded sentence, e.g. word count
and language model.
:type ext_scoring_func: function
:param blank_id: id of blank, default 0.
:type blank_id: int
:param num_processes: Number of processes, default None, equal to the
number of CPUs.
:type num_processes: int
:return: Decoding log probability and result string.
:return: Decoding log probabilities and result sentences in descending order.
:rtype: list
'''
@@ -243,7 +280,8 @@ def ctc_beam_search_decoder_nproc(probs_split,
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, None, nproc)
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))

pool.close()
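
With cutoff_prob threaded through to the workers, the parallel decoder can now prune inside each process. A hedged usage sketch follows; the random toy probabilities, the vocabulary, and the assumption that decoder.py is importable as decoder are illustrative, and the exact layout of the returned beams is not shown in this diff:

import numpy as np
from decoder import ctc_beam_search_decoder_nproc

def toy_probs_split(num_utts=2, num_steps=5, num_classes=5):
    # random per-step distributions over len(vocabulary) + 1 classes
    # (characters plus the blank), normalized to sum to 1
    split = []
    for _ in range(num_utts):
        probs = np.random.rand(num_steps, num_classes)
        split.append((probs / probs.sum(axis=1, keepdims=True)).tolist())
    return split

if __name__ == '__main__':
    vocabulary = list("abc ")
    probs_split = toy_probs_split(num_classes=len(vocabulary) + 1)
    beam_results = ctc_beam_search_decoder_nproc(
        probs_split=probs_split,
        beam_size=4,
        vocabulary=vocabulary,
        blank_id=0,
        cutoff_prob=0.99,       # prune each step's distribution to its top 99% mass
        ext_scoring_func=None,  # a language-model scorer could be plugged in here
        num_processes=2)
    # per the docstring, beams for each utterance come back in descending score order
    print(beam_results[0])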