Skip to content

Commit

Permalink
Merge branch 'PaddlePaddle:release/2.4' into release/2.4
Browse files Browse the repository at this point in the history
  • Loading branch information
livingbody authored Jan 14, 2022
2 parents beae94c + 22b1fb3 commit acd7772
Show file tree
Hide file tree
Showing 13 changed files with 792 additions and 251 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class Predictor {
protected String inputColorFormat = "BGR";
protected long[] inputShape = new long[]{1, 3, 960};
protected float[] inputMean = new float[]{0.485f, 0.456f, 0.406f};
protected float[] inputStd = new float[]{1.0f / 0.229f, 1.0f / 0.224f, 1.0f / 0.225f};
protected float[] inputStd = new float[]{0.229f, 0.224f, 0.225f};
protected float scoreThreshold = 0.1f;
protected Bitmap inputImage = null;
protected Bitmap outputImage = null;
Expand Down
4 changes: 2 additions & 2 deletions doc/doc_ch/algorithm_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ PaddleOCR开源的文本检测算法列表:
在ICDAR2015文本检测公开数据集上,算法效果如下:
|模型|骨干网络|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- |
|EAST|ResNet50_vd|85.80%|86.71%|86.25%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|MobileNetV3|79.42%|80.64%|80.03%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|EAST|ResNet50_vd|88.71%|81.36%|84.88%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|MobileNetV3|78.2%|79.1%|78.65%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|77.29%|73.08%|75.12%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
Expand Down
11 changes: 8 additions & 3 deletions doc/doc_ch/customize.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@
## step1:训练文本检测模型

PaddleOCR提供了EAST、DB两种文本检测算法,均支持MobileNetV3、ResNet50_vd两种骨干网络,根据需要选择相应的配置文件,启动训练。例如,训练使用MobileNetV3作为骨干网络的DB检测模型(即超轻量模型使用的配置):
```

```sh
python3 tools/train.py -c configs/det/det_mv3_db.yml 2>&1 | tee det_db.log
```

更详细的数据准备和训练教程参考文档教程中[文本检测模型训练/评估/预测](./detection.md)

## step2:训练文本识别模型

PaddleOCR提供了CRNN、Rosetta、STAR-Net、RARE四种文本识别算法,均支持MobileNetV3、ResNet34_vd两种骨干网络,根据需要选择相应的配置文件,启动训练。例如,训练使用MobileNetV3作为骨干网络的CRNN识别模型(即超轻量模型使用的配置):
```

```sh
python3 tools/train.py -c configs/rec/rec_chinese_lite_train.yml 2>&1 | tee rec_ch_lite.log
```

更详细的数据准备和训练教程参考文档教程中[文本识别模型训练/评估/预测](./recognition.md)

## step3:模型串联预测
Expand All @@ -24,7 +28,8 @@ PaddleOCR提供了检测和识别模型的串联工具,可以将训练好的

在执行预测时,需要通过参数image_dir指定单张图像或者图像集合的路径、参数det_model_dir指定检测inference模型的路径和参数rec_model_dir指定识别inference模型的路径。可视化识别结果默认保存到 ./inference_results 文件夹里面。

```
```sh
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
```

更多的文本检测、识别串联推理使用方式请参考文档教程中的[基于预测引擎推理](./inference.md)
53 changes: 39 additions & 14 deletions doc/doc_ch/knowledge_distillation.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
<a name="0"></a>
# 知识蒸馏


+ [知识蒸馏](#0)
+ [1. 简介](#1)
- [1.1 知识蒸馏介绍](#11)
- [1.2 PaddleOCR知识蒸馏简介](#12)
+ [2. 配置文件解析](#2)
+ [2.1 识别配置文件解析](#21)
- [2.1.1 模型结构](#211)
- [2.1.2 损失函数](#212)
- [2.1.3 后处理](#213)
- [2.1.4 指标计算](#214)
- [2.1.5 蒸馏模型微调](#215)
+ [2.2 检测配置文件解析](#22)
- [2.2.1 模型结构](#221)
- [2.2.2 损失函数](#222)
- [2.2.3 后处理](#223)
- [2.2.4 蒸馏指标计算](#224)
- [2.2.5 检测蒸馏模型Fine-tune](#225)

<a name="1"></a>
## 1. 简介

<a name="11"></a>
### 1.1 知识蒸馏介绍

近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
Expand All @@ -13,6 +32,7 @@

此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。

<a name="12"></a>
### 1.2 PaddleOCR知识蒸馏简介

无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上是都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。
Expand All @@ -30,17 +50,19 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升提升超过5%。



<a name="2"></a>
## 2. 配置文件解析

在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。

下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。

<a name="21"></a>
### 2.1 识别配置文件解析

配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)

<a name="211"></a>
#### 2.1.1 模型结构

知识蒸馏任务中,模型结构配置如下所示。
Expand Down Expand Up @@ -176,6 +198,7 @@ Architecture:
}
```

<a name="212"></a>
#### 2.1.2 损失函数

知识蒸馏任务中,损失函数配置如下所示。
Expand Down Expand Up @@ -212,7 +235,7 @@ Loss:

关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)。


<a name="213"></a>
#### 2.1.3 后处理

知识蒸馏任务中,后处理配置如下所示。
Expand All @@ -228,7 +251,7 @@ PostProcess:

关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)


<a name="214"></a>
#### 2.1.4 指标计算

知识蒸馏任务中,指标计算配置如下所示。
Expand All @@ -245,7 +268,7 @@ Metric:

关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。


<a name="215"></a>
#### 2.1.5 蒸馏模型微调

对蒸馏得到的识别蒸馏进行微调有2种方式。
Expand Down Expand Up @@ -279,15 +302,15 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")

转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。


<a name="22"></a>
### 2.2 检测配置文件解析

检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件:
- ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
- ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
- ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法


<a name="221"></a>
#### 2.2.1 模型结构

知识蒸馏任务中,模型结构配置如下所示:
Expand Down Expand Up @@ -419,7 +442,8 @@ Architecture:
}
```

#### 2.1.2 损失函数
<a name="222"></a>
#### 2.2.2 损失函数

知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。

Expand Down Expand Up @@ -484,8 +508,8 @@ Loss:

关于`DistillationDilaDBLoss`更加具体的实现可以参考: [distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/distillation_loss.py#L185)。关于`DistillationDBLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/04c44974b13163450dfb6bd2c327863f8a194b3c/ppocr/losses/distillation_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L148)


#### 2.1.3 后处理
<a name="223"></a>
#### 2.2.3 后处理

知识蒸馏任务中,检测蒸馏后处理配置如下所示。

Expand All @@ -504,7 +528,8 @@ PostProcess:
关于`DistillationDBPostProcess`更加具体的实现可以参考: [db_postprocess.py](../../ppocr/postprocess/db_postprocess.py#L195)


#### 2.1.4 蒸馏指标计算
<a name="224"></a>
#### 2.2.4 蒸馏指标计算

知识蒸馏任务中,检测蒸馏指标计算配置如下所示。

Expand All @@ -518,8 +543,8 @@ Metric:

由于蒸馏需要包含多个网络,甚至多个Student网络,在计算指标的时候只需要计算一个Student网络的指标即可,`key`字段设置为`Student`则表示只计算`Student`网络的精度。


#### 2.1.5 检测蒸馏模型finetune
<a name="225"></a>
#### 2.2.5 检测蒸馏模型finetune

检测蒸馏有三种方式:
- 采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
Expand Down
4 changes: 2 additions & 2 deletions doc/doc_en/algorithm_overview_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:

|Model|Backbone|Precision|Recall|Hmean|Download link|
| --- | --- | --- | --- | --- | --- |
|EAST|ResNet50_vd|85.80%|86.71%|86.25%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|MobileNetV3|79.42%|80.64%|80.03%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|EAST|ResNet50_vd|88.71%|81.36%|84.88%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|MobileNetV3|78.2%|79.1%|78.65%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
Expand Down
Loading

0 comments on commit acd7772

Please sign in to comment.