-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Yibing Liu
committed
Jun 1, 2017
1 parent
367e123
commit 504b15c
Showing
1 changed file
with
152 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
## This is a prototype of ctc beam search decoder | ||
|
||
import copy | ||
import random | ||
import numpy as np | ||
|
||
# Decoder vocabulary: four letters plus space (word separator) and '-' (blank).
# A full English setup would instead be: ['-', ' '] + [chr(i) for i in range(97, 123)]
vocab = ['a', 'b', 'c', 'd', ' ', '-']
|
||
def ids_str2list(ids_str):
    """Parse a space-separated id string (e.g. '-1 0 2') into a list of ints."""
    return [int(token) for token in ids_str.split(' ')]
|
||
def ids_list2str(ids_list):
    """Render a list of ints as a space-separated id string (inverse of ids_str2list)."""
    return ' '.join(str(elem) for elem in ids_list)
|
||
def ctc_beam_search_decoder(
        input_probs_matrix,
        beam_size,
        lang_model=None,
        name=None,
        alpha=1.0,
        beta=1.0,
        blank_id=0,
        space_id=1,
        num_results_per_sample=None):
    '''
    Beam search decoder for a CTC-trained network, called outside of the
    recurrent group; adapted from Algorithm 1 in
    https://arxiv.org/abs/1408.2873.

    :param input_probs_matrix: per-timestep symbol probabilities, indexable
        as input_probs_matrix[t][c] (time_steps x vocab_dim).
    :param beam_size: number of prefixes kept alive after each time step.
    :param lang_model: optional callable scoring a prefix (list of ids,
        including the -1 start marker); applied with exponent ``alpha``
        whenever a space symbol extends a prefix.
    :param name: unused; kept for interface compatibility.
    :param alpha: language-model weight (exponent on the LM score).
    :param beta: unused; kept for interface compatibility.
    :param blank_id: vocabulary id of the CTC blank symbol.
    :param space_id: vocabulary id of the word-separator symbol.
    :param num_results_per_sample: number of hypotheses to return;
        defaults to ``beam_size``.
    :return: list of ``[log_prob, ids_list]`` pairs, best first; the -1
        start marker is stripped from each ids_list.
    '''
    if num_results_per_sample is None:
        num_results_per_sample = beam_size
    assert num_results_per_sample <= beam_size

    max_time_steps = len(input_probs_matrix)
    assert max_time_steps > 0

    vocab_dim = len(input_probs_matrix[0])
    assert blank_id < vocab_dim
    assert space_id < vocab_dim

    # Prefixes are encoded as space-separated id strings; -1 marks the start.
    # probs_b / probs_nb: probability of each prefix ending in blank / non-blank.
    start_id = -1
    prefix_set_prev = {str(start_id): 1.0}
    probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0}

    # extend each surviving prefix one time step at a time
    for time_step in range(max_time_steps):
        print("\ntime_step = %d" % (time_step + 1))
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for (l, prob_) in prefix_set_prev.items():
            print("l = %s\t%f" % (l, prob_))
            prob = input_probs_matrix[time_step]

            # convert the encoded id string back to a list of ints
            ids_list = [int(token) for token in l.split(' ')]
            end_id = ids_list[-1]
            if l not in probs_b_cur:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # try extending the prefix with every vocabulary symbol
            for c in range(0, vocab_dim):
                if c == blank_id:
                    # blank keeps the prefix unchanged, ending in blank
                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
                else:
                    l_plus = l + ' ' + str(c)
                    if l_plus not in probs_b_cur:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if c == end_id:
                        # repeated symbol: a new emission needs a preceding
                        # blank; otherwise it collapses into the same prefix
                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
                        probs_nb_cur[l] += prob[c] * probs_nb[l]
                    elif c == space_id:
                        # word boundary: optionally rescore with the LM
                        lm = 1 if lang_model is None \
                            else np.power(lang_model(ids_list), alpha)
                        probs_nb_cur[l_plus] += lm * prob[c] * (probs_b[l] + probs_nb[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (probs_b[l] + probs_nb[l])
                    prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus]

                    print("l_plus: %s\t b=%f\tnb=%f\tP=%f" % (
                        l_plus, probs_b_cur[l_plus], probs_nb_cur[l_plus],
                        prefix_set_next[l_plus]))
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
            print("l: %s\t b=%f\tnb=%f\tP=%f" % (
                l, probs_b_cur[l], probs_nb_cur[l], prefix_set_next[l]))
        # the *_cur dicts are rebuilt from scratch next iteration, so plain
        # rebinding suffices (the original deep-copied flat str->float dicts)
        probs_b, probs_nb = probs_b_cur, probs_nb_cur

        # prune: keep only the beam_size most probable prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda kv: kv[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    # collect surviving prefixes, dropping the -1 start marker
    beam_result = []
    for (seq, prob) in prefix_set_prev.items():
        if prob > 0.0:
            ids_list = [int(token) for token in seq.split(' ')]
            log_prob = np.log(prob)
            beam_result.append([log_prob, ids_list[1:]])

    # best hypotheses first, trimmed to the requested count
    beam_result = sorted(beam_result, key=lambda res: res[0], reverse=True)
    if num_results_per_sample < beam_size:
        beam_result = beam_result[:num_results_per_sample]
    return beam_result
|
||
def language_model(input):
    """Placeholder LM score for a prefix: a random value in [0, 1].

    TODO: replace this stub with a real language model.
    """
    return random.uniform(0, 1)
|
||
def ctc_net(input, size_vocab):
    """Dummy acoustic model: a uniform distribution over the module vocab.

    (A randomized variant would be np.array([random.uniform(0, 1) ...])
    followed by the same normalization.)
    """
    dim = len(vocab)
    scores = np.ones(dim)
    return scores / scores.sum()
|
||
def main():
    """Decode a fixed probability matrix with the beam search decoder and
    print each hypothesis as its log-probability and decoded string.
    """
    # Per-timestep symbol probabilities over the 6-symbol vocab;
    # column 5 is the blank, matching blank_id below.
    prob_matrix = [
        [0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
        [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
        [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
        [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
        [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
        # Random entry added in at time=5
        #[0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
    ]

    beam_result = ctc_beam_search_decoder(
        input_probs_matrix=prob_matrix,
        beam_size=2,
        blank_id=5,
    )

    def ids2str(ids_list):
        # Map token ids back to characters via the module-level vocab.
        return ''.join(vocab[ids] for ids in ids_list)

    print("\nbeam search output:")
    for result in beam_result:
        print(result[0], ids2str(result[1]))
|
||
# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()