update ctc for application of STR #253
Conversation
Commenting on a first portion for now; please make absolutely sure every described behavior is technically accurate.
@@ -0,0 +1,475 @@
# CRNN Tutorial for the CTC (Connectionist Temporal Classification) Model
- Please change the title to: Scene Text Recognition (STR).
# CRNN Tutorial for the CTC (Connectionist Temporal Classification) Model
## Background

Sequence learning tasks in the real world require predicting a corresponding label sequence from an input sequence,
This sentence does not work: real-world sequence learning tasks are many and varied, and not all of them "require predicting a corresponding label sequence from an input sequence".
Sequence learning tasks in the real world require predicting a corresponding label sequence from an input sequence;
for example, speech recognition derives the corresponding text sequence from continuous speech.
CTC-related models are a class of algorithms for such tasks. Concretely, a CTC model performs one classification per timestep of the input sequence, emitting one label (the source of "Classification" in CTC),
- CTC is not an algorithm; it is an end-to-end loss function designed for sequence labeling tasks ($m$ inputs labeled with $n$ outputs, $m \ne n$).
- CTC is a loss function.
1. `layer.warp_ctc` wraps `warp CTC` to implement the CTC layer.
2. Multiple `conv_group` blocks quickly build a deep CNN.
3. `layer.block_expand` cuts a matrix into a sequence of fixed-size blocks.
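The block slicing in item 3 can be illustrated with a small NumPy sketch. This only shows the idea of turning a 2-D feature map into a timestep sequence; PaddlePaddle's actual `layer.block_expand` layout and parameters may differ, and all names here are illustrative.

```python
import numpy as np

# Toy feature map (channels, height, width), e.g. the output of the CNN.
C, H, W = 2, 3, 8
feature_map = np.arange(C * H * W, dtype=float).reshape(C, H, W)

# Cut the map along the width into fixed-size blocks and flatten each
# block into one feature vector, yielding a sequence for the RNN.
block_w = 2
blocks = [feature_map[:, :, x:x + block_w].ravel()
          for x in range(0, W, block_w)]
seq = np.stack(blocks)

print(seq.shape)  # (4, 12): W // block_w timesteps, C * H * block_w features each
```

Each column strip of the image becomes one RNN timestep, which is what lets a 2-D image be treated as a 1-D sequence.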
Please note that every sentence should end with a punctuation mark, and use the Chinese full-width period rather than mixing Chinese and English punctuation.
Finally, the emitted label sequence is processed into the corresponding output sequence (the algorithm is described below).
The CTC algorithm is applied in many areas, such as handwritten digit recognition, speech recognition, gesture recognition, and text recognition in images. Apart from the different domain expertise involved,
all of these tasks take a continuous sequence as input and produce a label sequence as output.
- Please spend some time consulting references and rewrite the background on CTC; it is technically imprecise and not rigorous.
- It does not need to be long, but it must be technically accurate. An earlier, very brief (overly simple) note of mine, for reference only:
https://github.com/lcy-seso/deeplearning-papernotes/blob/master/ctc_loss/CTC_loss_function.md
0.3 * 0.21 * 0.22 = 0.01386
$$
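The product above is easy to check with a few lines of Python (the three per-timestep probabilities are the values quoted from the document's table):

```python
import math

# Per-timestep probabilities of emitting "h", "e", "l" (from the table above).
p_h, p_e, p_l = 0.3, 0.21, 0.22

p_hel = math.prod([p_h, p_e, p_l])
print(round(p_hel, 5))  # 0.01386
```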

The CRNN model inherits the `CTC layer` from \[[3](#参考文献)\],
What does "inherits" mean here? It reads oddly; please revise.
$$

The CRNN model inherits the `CTC layer` from \[[3](#参考文献)\];
unlike the beam search algorithm used in classic NMT (Neural Machine Translation)\[[7](#参考文献)\], the CTC layer does not consider the labels already generated, only the probability of generating a given label at the current timestep.
- There is no need to mention NMT and beam search here; please remove them.
- "the CTC layer does not consider the labels already generated, only the probability of generating a given label at the current timestep." Please reconsider the subject of this sentence.

Here, $l$ is the target sequence, $y$ is the sequence of label distributions, and $\pi$ is a mapping that turns the predicted distribution sequence into the target label sequence.

### How Training and Prediction Work
Training and Prediction
for example, by removing blanks and repeated characters: if the prediction $l*$ is `-h-el-ll-o`, it is converted to `hello` as the final output sequence.

This concludes the overview of the CRNN model: given raw image input, a CNN learns to extract image features and turns them into a sequence of feature vectors, which is passed to an RNN;
the RNN produces a label probability distribution per timestep, and the CTC layer sums the probabilities of all mappings that yield the target sequence, giving the model's predicted probability of the target sequence (the training loss).
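The collapse-and-sum just described can be made concrete with a brute-force sketch. This is illustrative only: real CTC implementations use the forward-backward dynamic program rather than enumerating paths, and all names below are made up for the example.

```python
from itertools import product

def collapse(path, blank="-"):
    """CTC's B mapping: merge consecutive repeats, then drop blanks.
    For example, "-h-el-ll-o" collapses to "hello"."""
    out, prev = [], None
    for sym in path:
        if sym != prev and sym != blank:
            out.append(sym)
        prev = sym
    return "".join(out)

def ctc_target_prob(target, probs, alphabet):
    """Brute-force P(target | input): sum the probability of every
    length-T path whose collapsed form equals the target.  Exponential
    in T -- shown only to make the definition concrete."""
    total = 0.0
    for path in product(alphabet, repeat=len(probs)):
        if collapse(path) == target:
            p = 1.0
            for t, sym in enumerate(path):
                p *= probs[t][alphabet.index(sym)]
            total += p
    return total
```

With a two-symbol alphabet `["a", "-"]` and two timesteps of uniform probabilities, the paths `aa`, `a-`, and `-a` all collapse to `"a"`, so `ctc_target_prob("a", [[0.5, 0.5], [0.5, 0.5]], ["a", "-"])` returns 0.75.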
- "the RNN produces a label probability distribution per timestep" is incorrect: this probability distribution is not produced by the RNN.
- Please avoid the term "CTC Layer".

## Implementing the Model Algorithm with PaddlePaddle
Training and Prediction with PaddlePaddle
## A Brief Introduction to STR

Text appears in many real-life scenes, including road signs, menus, and building banners; the text in photos of these scenes provides extra information for understanding them.
Google has already used AI algorithms to automatically recognize the text on road signs to obtain more accurate Street View addresses\[[2](#参考文献)\].
[2] uses a deep learning model to automatically recognize the text on road signs, helping the Street View application obtain more accurate address information.

In this tutorial, the training data consists of images like the one below, which should be recognized as *"keep"*.
This tutorial trains on images like the one below; the text to be recognized is "keep".

By looking up the table above, it is easy to compute the probability of the model generating any label sequence; for example, the probability of generating the three characters "hel" is computed as:

$$
LaTeX markup is not needed here; please use a plain text block instead.

Given the label distributions, multiple mappings can turn a distribution sequence into the target sequence. For example, a 10-frame input feature sequence can generate the target sequence "hello" via the following mappings (`-` denotes the blank):
This sentence is too colloquial.
- `-h-el-ll-o`
- `hello----`
- `-h-e-l-lo`
- several others
"several others" → replace with an ellipsis.
act=Linear())
```

The above uses an `fc` fully connected layer; note that its input is `input=[gru_forward, gru_backward]`
Rewrite this paragraph; it is imprecise. `fc` does not perform concatenation.
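As the review points out, an `fc` layer with several inputs does not concatenate them; to my understanding it keeps one projection matrix per input and sums the projections, which is mathematically equivalent to projecting the concatenation with a stacked weight matrix. A NumPy sketch of that equivalence (all names here are illustrative, not PaddlePaddle API):

```python
import numpy as np

rng = np.random.default_rng(0)
T, H, C = 4, 8, 5          # timesteps, GRU hidden size, output dimension

gru_forward = rng.normal(size=(T, H))
gru_backward = rng.normal(size=(T, H))

# One projection matrix per input; the layer sums the projections.
W_f = rng.normal(size=(H, C))
W_b = rng.normal(size=(H, C))
out_sum = gru_forward @ W_f + gru_backward @ W_b

# Equivalent view: projecting the concatenation with a stacked weight
# matrix gives the same result, but the layer itself never concatenates.
out_concat = (np.concatenate([gru_forward, gru_backward], axis=1)
              @ np.vstack([W_f, W_b]))

assert np.allclose(out_sum, out_concat)
```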
then `fc` maps it to a vector of dimension `self.num_classes + 1` (the extra 1 is for the blank),
and the timesteps together form a sequence.

This sequence is then fed to the `CTC layer`; here we use `layer.warp_ctc`, a wrapper around warp CTC\[[5](#参考文献)\]:
Too colloquial; please rewrite.

The parameters are: the sequence of label-distribution vectors `self.output` produced by the preceding `fc`, the target label sequence `self.label`, the label dictionary size `self.num_classes+1`, per-timestep normalization set to `True`, and `self.num_classes` as the class ID of the blank.
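Putting the parameters just listed together, the call might look like the following sketch. This assumes the legacy `paddle.v2.layer` API generation; the exact module path and keyword names are assumptions, not verified against this repository's code.

```python
# Sketch only; module path and keyword names are assumptions.
self.cost = paddle.v2.layer.warp_ctc(
    input=self.output,          # label-distribution sequence from the fc layer
    label=self.label,           # target label sequence
    size=self.num_classes + 1,  # dictionary size plus one for the blank
    norm_by_times=True,         # normalize the loss per timestep
    blank=self.num_classes)     # class ID reserved for the blank
```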

This completes the model configuration; next, the training configuration:
- Do not use ambiguous wording like "basically complete". If it is complete, say "complete"; if something is omitted there and omitting it hurts understanding, state it explicitly. Do not be ambiguous.
4. During training, model parameters are automatically checkpointed to the specified directory, by default ./model.ctc.
5. Set the relevant parameters in infer.py and run ```python infer.py``` for prediction.

## Closing Remarks
- Delete lines 451 and 459 and merge all of it under a single section: "Notes".

### Useful Datasets
Other Datasets
This is a follow-up work for #63