-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathctc_decoder.py
69 lines (61 loc) · 2.61 KB
/
ctc_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File :rnn-transducer -> ctc_decoder
@IDE :PyCharm
@Author :chengli
@Date :2020/11/12 6:48 PM
@Desc :
=================================================='''
import numpy as np
def softmax(logits):
# 注意这里求e的次方时,次方数减去max_value其实不影响结果,因为最后可以化简成教科书上softmax的定义
# 次方数加入减max_value是因为e的x次方与x的极限(x趋于无穷)为无穷,很容易溢出,所以为了计算时不溢出,就加入减max_value项
# 次方数减去max_value后,e的该次方数总是在0到1范围内。
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def remove_blank(labels, blank=0):
new_labels = []
# 合并相同的标签
previous = None
for l in labels:
if l != previous:
new_labels.append(l)
previous = l
# 删除blank
new_labels = [l for l in new_labels if l != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def beam_decode(y, beam_size=10):
# y是个二维数组,记录了所有时刻的所有项的概率
T, V = y.shape
# 将所有的y中值改为log是为了防止溢出,因为最后得到的p是y1..yn连乘,且yi都在0到1之间,可能会导致下溢出
# 改成log(y)以后就变成连加了,这样就防止了下溢出
log_y = np.log(y)
# 初始的beam
beam = [([], 0)]
# 遍历所有时刻t
for t in range(T):
# 每个时刻先初始化一个new_beam
new_beam = []
# 遍历beam
for prefix, score in beam:
# 对于一个时刻中的每一项(一共V项)
for i in range(V):
# 记录添加的新项是这个时刻的第几项,对应的概率(log形式的)加上新的这项log形式的概率(本来是乘的,改成log就是加)
new_prefix = prefix + [i]
new_score = score + log_y[t, i]
# new_beam记录了对于beam中某一项,将这个项分别加上新的时刻中的每一项后的概率
new_beam.append((new_prefix, new_score))
# 给new_beam按score排序
new_beam.sort(key=lambda x: x[1], reverse=True)
# beam即为new_beam中概率最大的beam_size个路径
beam = new_beam[:beam_size]
return beam