Skip to content

Commit

Permalink
Merge pull request #5118 from WZMIAOMIAO/wzmiaomiao
Browse files Browse the repository at this point in the history
translate  knowledge_distillation_en.md
  • Loading branch information
Evezerest authored Jan 6, 2022
2 parents dbe527e + 9e7601d commit 56f426f
Show file tree
Hide file tree
Showing 2 changed files with 631 additions and 14 deletions.
53 changes: 39 additions & 14 deletions doc/doc_ch/knowledge_distillation.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
<a name="0"></a>
# 知识蒸馏


+ [知识蒸馏](#0)
+ [1. 简介](#1)
- [1.1 知识蒸馏介绍](#11)
- [1.2 PaddleOCR知识蒸馏简介](#12)
+ [2. 配置文件解析](#2)
+ [2.1 识别配置文件解析](#21)
- [2.1.1 模型结构](#211)
- [2.1.2 损失函数](#212)
- [2.1.3 后处理](#213)
- [2.1.4 指标计算](#214)
- [2.1.5 蒸馏模型微调](#215)
+ [2.2 检测配置文件解析](#22)
- [2.2.1 模型结构](#221)
- [2.2.2 损失函数](#222)
- [2.2.3 后处理](#223)
- [2.2.4 蒸馏指标计算](#224)
- [2.2.5 检测蒸馏模型Fine-tune](#225)

<a name="1"></a>
## 1. 简介

<a name="11"></a>
### 1.1 知识蒸馏介绍

近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
Expand All @@ -13,6 +32,7 @@

此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。

<a name="12"></a>
### 1.2 PaddleOCR知识蒸馏简介

无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上是都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。
Expand All @@ -30,17 +50,19 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升提升超过5%。



<a name="2"></a>
## 2. 配置文件解析

在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。

下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。

<a name="21"></a>
### 2.1 识别配置文件解析

配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)

<a name="211"></a>
#### 2.1.1 模型结构

知识蒸馏任务中,模型结构配置如下所示。
Expand Down Expand Up @@ -176,6 +198,7 @@ Architecture:
}
```

<a name="212"></a>
#### 2.1.2 损失函数

知识蒸馏任务中,损失函数配置如下所示。
Expand Down Expand Up @@ -212,7 +235,7 @@ Loss:

关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)。


<a name="213"></a>
#### 2.1.3 后处理

知识蒸馏任务中,后处理配置如下所示。
Expand All @@ -228,7 +251,7 @@ PostProcess:

关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)


<a name="214"></a>
#### 2.1.4 指标计算

知识蒸馏任务中,指标计算配置如下所示。
Expand All @@ -245,7 +268,7 @@ Metric:

关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。


<a name="215"></a>
#### 2.1.5 蒸馏模型微调

对蒸馏得到的识别蒸馏进行微调有2种方式。
Expand Down Expand Up @@ -279,15 +302,15 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")

转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。


<a name="22"></a>
### 2.2 检测配置文件解析

检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件:
- ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
- ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
- ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法


<a name="221"></a>
#### 2.2.1 模型结构

知识蒸馏任务中,模型结构配置如下所示:
Expand Down Expand Up @@ -419,7 +442,8 @@ Architecture:
}
```

#### 2.1.2 损失函数
<a name="222"></a>
#### 2.2.2 损失函数

知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。

Expand Down Expand Up @@ -484,8 +508,8 @@ Loss:

关于`DistillationDilaDBLoss`更加具体的实现可以参考: [distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/distillation_loss.py#L185)。关于`DistillationDBLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/04c44974b13163450dfb6bd2c327863f8a194b3c/ppocr/losses/distillation_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L148)


#### 2.1.3 后处理
<a name="223"></a>
#### 2.2.3 后处理

知识蒸馏任务中,检测蒸馏后处理配置如下所示。

Expand All @@ -504,7 +528,8 @@ PostProcess:
关于`DistillationDBPostProcess`更加具体的实现可以参考: [db_postprocess.py](../../ppocr/postprocess/db_postprocess.py#L195)


#### 2.1.4 蒸馏指标计算
<a name="224"></a>
#### 2.2.4 蒸馏指标计算

知识蒸馏任务中,检测蒸馏指标计算配置如下所示。

Expand All @@ -518,8 +543,8 @@ Metric:

由于蒸馏需要包含多个网络,甚至多个Student网络,在计算指标的时候只需要计算一个Student网络的指标即可,`key`字段设置为`Student`则表示只计算`Student`网络的精度。


#### 2.1.5 检测蒸馏模型finetune
<a name="225"></a>
#### 2.2.5 检测蒸馏模型finetune

检测蒸馏有三种方式:
- 采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
Expand Down
Loading

0 comments on commit 56f426f

Please sign in to comment.