Skip to content

Commit

Permalink
add ctc beam search decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Yibing Liu committed Jun 1, 2017
1 parent 367e123 commit 504b15c
Showing 1 changed file with 152 additions and 0 deletions.
152 changes: 152 additions & 0 deletions ctc_beam_search_decoder/ctc_beam_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
## This is a prototype of ctc beam search decoder

import copy
import random
import numpy as np

# vocab = English characters + blank + space
#vocab = ['-', ' '] + [chr(i) for i in range(97, 123)]

# Reduced test vocabulary used by this prototype: ['a', 'b', 'c', 'd', ' ', '-'].
# Index 4 is the space (word separator) and index 5 the CTC blank in main().
vocab = [chr(97), chr(98), chr(99), chr(100)]+[' ', '-']

def ids_str2list(ids_str):
    """Convert a whitespace-separated id string (e.g. '-1 2 0') to a list of ints.

    Uses split() with no argument so repeated, leading, or trailing whitespace
    is tolerated; the original split(' ') produced empty tokens there and made
    int('') raise ValueError.
    """
    return [int(token) for token in ids_str.split()]

def ids_list2str(ids_list):
    """Render a list of integer ids as one space-separated string."""
    return ' '.join(str(ids) for ids in ids_list)

def ctc_beam_search_decoder(
        input_probs_matrix,
        beam_size,
        lang_model=None,
        name=None,
        alpha=1.0,
        beta=1.0,
        blank_id=0,
        space_id=1,
        num_results_per_sample=None):
    '''
    Prefix beam search decoder for a CTC-trained network, adapted from
    Algorithm 1 in https://arxiv.org/abs/1408.2873.

    Fixed to run on Python 3: the original used print statements,
    dict.has_key() and dict.iteritems(), all of which are errors in Python 3.

    :param input_probs_matrix: per-time-step probability vectors over the
        vocabulary (time_steps x vocab_dim), indexable rows of floats.
    :param beam_size: number of prefixes kept after each time step.
    :param lang_model: optional callable scoring a list of ids; its value is
        raised to the power ``alpha`` whenever a space is appended.
    :param name: unused; kept for interface compatibility.
    :param alpha: language-model weight exponent.
    :param beta: unused here (word-insertion weight in the paper).
    :param blank_id: index of the CTC blank symbol.
    :param space_id: index of the word-separator symbol.
    :param num_results_per_sample: number of hypotheses returned; defaults to
        ``beam_size``.
    :return: list of [log_prob, id_list] pairs, sorted by descending log_prob.
    '''
    if num_results_per_sample is None:
        num_results_per_sample = beam_size
    assert num_results_per_sample <= beam_size

    max_time_steps = len(input_probs_matrix)
    assert max_time_steps > 0

    vocab_dim = len(input_probs_matrix[0])
    assert blank_id < vocab_dim
    assert space_id < vocab_dim

    # Prefixes are stored as space-joined id strings; the sentinel id -1 marks
    # the empty prefix and is stripped from the final results.  probs_b and
    # probs_nb hold, per prefix, the probability of all paths ending in blank
    # and in non-blank respectively.
    start_id = -1
    prefix_set_prev = {str(start_id): 1.0}
    probs_b = {str(start_id): 1.0}
    probs_nb = {str(start_id): 0.0}

    # Extend every surviving prefix at each time step.
    for time_step in range(max_time_steps):
        print("\ntime_step = %d" % (time_step + 1))
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for l, prob_ in prefix_set_prev.items():
            print("l = %s\t%f" % (l, prob_))
            prob = input_probs_matrix[time_step]

            # Parse the id string (inlined split; ids are space-joined ints).
            ids_list = [int(token) for token in l.split(' ')]
            end_id = ids_list[-1]
            if l not in probs_b_cur:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # Try appending every vocabulary symbol to this prefix.
            for c in range(0, vocab_dim):
                if c == blank_id:
                    # Blank keeps the prefix unchanged.
                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
                else:
                    l_plus = l + ' ' + str(c)
                    if l_plus not in probs_b_cur:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if c == end_id:
                        # Repeated symbol: it only counts as a new emission
                        # when a blank separated the two; otherwise it merges
                        # into the existing prefix.
                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
                        probs_nb_cur[l] += prob[c] * probs_nb[l]
                    elif c == space_id:
                        # Word boundary: apply the optional language model.
                        lm = 1.0 if lang_model is None \
                            else np.power(lang_model(ids_list), alpha)
                        probs_nb_cur[l_plus] += lm * prob[c] * (probs_b[l] + probs_nb[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (probs_b[l] + probs_nb[l])
                    prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus]

                    print("l_plus: %s\t b=%f\tnb=%f\tP=%f" % (
                        l_plus, probs_b_cur[l_plus], probs_nb_cur[l_plus],
                        prefix_set_next[l_plus]))
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
            print("l: %s\t b=%f\tnb=%f\tP=%f" % (
                l, probs_b_cur[l], probs_nb_cur[l], prefix_set_next[l]))
        probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(probs_nb_cur)

        # Prune to the beam_size most probable prefixes.
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda kv: kv[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    # Convert surviving prefixes to [log_prob, ids] pairs, dropping the
    # sentinel start id.
    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0:
            ids_list = [int(token) for token in seq.split(' ')]
            beam_result.append([np.log(prob), ids_list[1:]])

    beam_result = sorted(beam_result, key=lambda pair: pair[0], reverse=True)
    if num_results_per_sample < beam_size:
        beam_result = beam_result[:num_results_per_sample]
    return beam_result

def language_model(input):
    # Placeholder LM scorer: ignores its argument (a list of ids) and returns
    # a random score in [0, 1] — nondeterministic, for plumbing tests only.
    # NOTE(review): `input` shadows the builtin of the same name; rename when
    # the real model lands.
    # TODO
    return random.uniform(0, 1)

def ctc_net(input, size_vocab):
    """Stand-in acoustic model: return a uniform probability vector over the
    module-level ``vocab`` (both arguments are currently ignored).

    A random distribution can be simulated instead by filling the vector with
    random.uniform(0, 1) draws before renormalizing.
    """
    uniform = np.ones(len(vocab))
    return uniform / uniform.sum()

def main():
    """Demo driver: decode a fixed 5x6 probability matrix (blank at index 5)
    with the beam search decoder and print the candidate transcriptions.

    Fixed to run on Python 3 (print statements -> print()); also removed the
    unused 3x3 example matrix that was never passed to the decoder.
    """
    prob_matrix = [[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
                   [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
                   [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
                   [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
                   [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
                   # Random entry added in at time=5
                   #[0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
                  ]

    beam_result = ctc_beam_search_decoder(
        input_probs_matrix=prob_matrix,
        beam_size=2,
        blank_id=5,
    )

    def ids2str(ids_list):
        # Map integer ids back to characters via the module-level vocab.
        return ''.join(vocab[ids] for ids in ids_list)

    print("\nbeam search output:")
    for result in beam_result:
        print(result[0], ids2str(result[1]))

if __name__ == '__main__':
main()

0 comments on commit 504b15c

Please sign in to comment.