Skip to content

Commit

Permalink
adapt to the new data provider
Browse files Browse the repository at this point in the history
  • Loading branch information
Yibing Liu committed Jun 20, 2017
1 parent 06f272a commit 5c4751e
Show file tree
Hide file tree
Showing 4 changed files with 813 additions and 31 deletions.
265 changes: 247 additions & 18 deletions deep_speech_2/decoder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Contains various CTC decoder."""
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
from itertools import groupby
import numpy as np
import kenlm
import multiprocessing


def ctc_best_path_decode(probs_seq, vocabulary):
Expand Down Expand Up @@ -36,24 +39,250 @@ def ctc_best_path_decode(probs_seq, vocabulary):
return ''.join([vocabulary[index] for index in index_list])


def ctc_decode(probs_seq, vocabulary, method):
"""CTC-like sequence decoding from a sequence of likelihood probablilites.
class Scorer(object):
"""External defined scorer to evaluate a sentence in beam search
decoding, consisting of language model and word count.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param alpha: Parameter associated with language model.
:type alpha: float
:param beta: Parameter associated with word count.
:type beta: float
:model_path: Path to load language model.
:type model_path: basestring
"""

def __init__(self, alpha, beta, model_path):
self._alpha = alpha
self._beta = beta
if not os.path.isfile(model_path):
raise IOError("Invaid language model path: %s" % model_path)
self._language_model = kenlm.LanguageModel(model_path)

# n-gram language model scoring
def language_model_score(self, sentence):
#log prob of last word
log_cond_prob = list(
self._language_model.full_scores(sentence, eos=False))[-1][0]
return np.power(10, log_cond_prob)

# word insertion term
def word_count(self, sentence):
words = sentence.strip().split(' ')
return len(words)

# execute evaluation
def __call__(self, sentence, log=False):
"""Evaluation function, gathering all the scores.
:param sentence: The input sentence for evalutation
:type sentence: basestring
:param log: Whether return the score in log representation.
:type log: bool
:return: Evaluation score, in the decimal or log.
:rtype: float
"""
lm = self.language_model_score(sentence)
word_cnt = self.word_count(sentence)
if log == False:
score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta)
else:
score = self._alpha * np.log(lm) \
+ self._beta * np.log(word_cnt)
return score


def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
blank_id=0,
cutoff_prob=1.0,
ext_scoring_func=None,
nproc=False):
'''Beam search decoder for CTC-trained network, using beam search with width
beam_size to find many paths to one label, return beam_size labels in
the descending order of probabilities. The implementation is based on Prefix
Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned.
:param probs_seq: 2-D list with length num_time_steps, each element
is a list of normalized probabilities over vocabulary
and blank for one time step.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param method: Decoding method name, with options: "best_path".
:type method: basestring
:return: Decoding result string.
:rtype: baseline
"""
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for
partially decoded sentence, e.g. word count
and language model.
:type external_scoring_function: function
:param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool
:return: Decoding log probabilities and result sentences in descending order.
:rtype: list
'''
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary")
if method == "best_path":
return ctc_best_path_decode(probs_seq, vocabulary)
else:
raise ValueError("Decoding method [%s] is not supported.")
raise ValueError("probs dimension mismatched with vocabulary")
num_time_steps = len(probs_seq)

# blank_id check
probs_dim = len(probs_seq[0])
if not blank_id < probs_dim:
raise ValueError("blank_id shouldn't be greater than probs dimension")

# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_nproc().
if nproc is True:
global ext_nproc_scorer
ext_scoring_func = ext_nproc_scorer

## initialize
# the set containing selected prefixes
prefix_set_prev = {'\t': 1.0}
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

## extend prefix in loop
for time_step in xrange(num_time_steps):
# the set containing candidate prefixes
prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {}
prob = probs_seq[time_step]
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
cutoff_len = len(prob_idx)
#If pruning is enabled
if (cutoff_prob < 1.0):
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len = 0
cum_prob = 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
prob_idx = prob_idx[0:cutoff_len]

for l in prefix_set_prev:
if not prefix_set_next.has_key(l):
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

# extend prefix by travering prob_idx
for index in xrange(cutoff_len):
c, prob_c = prob_idx[index][0], prob_idx[index][1]

if c == blank_id:
probs_b_cur[l] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
last_char = l[-1]
new_char = vocabulary[c]
l_plus = l + new_char
if not prefix_set_next.has_key(l_plus):
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

if new_char == last_char:
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
elif new_char == ' ':
if (ext_scoring_func is None) or (len(l) == 1):
score = 1.0
else:
prefix = l[1:]
score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs
probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

## store top beam_size prefixes
prefix_set_prev = sorted(
prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
if beam_size < len(prefix_set_prev):
prefix_set_prev = prefix_set_prev[:beam_size]
prefix_set_prev = dict(prefix_set_prev)

beam_result = []
for (seq, prob) in prefix_set_prev.items():
if prob > 0.0 and len(seq) > 1:
result = seq[1:]
# score last word by external scorer
if (ext_scoring_func is not None) and (result[-1] != ' '):
prob = prob * ext_scoring_func(result)
log_prob = np.log(prob)
beam_result.append([log_prob, result])

## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result


def ctc_beam_search_decoder_nproc(probs_split,
beam_size,
vocabulary,
blank_id=0,
cutoff_prob=1.0,
ext_scoring_func=None,
num_processes=None):
'''Beam search decoder using multiple processes.
:param probs_seq: 3-D list with length batch_size, each element
is a 2-D list of probabilities can be used by
ctc_beam_search_decoder.
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for
partially decoded sentence, e.g. word count
and language model.
:type external_scoring_function: function
:param num_processes: Number of processes, default None, equal to the
number of CPUs.
:type num_processes: int
:return: Decoding log probabilities and result sentences in descending order.
:rtype: list
'''
if num_processes is None:
num_processes = multiprocessing.cpu_count()
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")

# use global variable to pass the externnal scorer to beam search decoder
global ext_nproc_scorer
ext_nproc_scorer = ext_scoring_func
nproc = True

pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))

pool.close()
pool.join()
beam_search_results = []
for result in results:
beam_search_results.append(result.get())
return beam_search_results
Loading

0 comments on commit 5c4751e

Please sign in to comment.