Skip to content

Commit

Permalink
fix rec distillation (PaddlePaddle#3995)
Browse files Browse the repository at this point in the history
* fix rec distillation

* add dist cfg

* fix cfg
  • Loading branch information
littletomatodonkey authored Sep 9, 2021
1 parent 3805cc2 commit cc24646
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 95 deletions.
88 changes: 20 additions & 68 deletions configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Global:
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
save_model_dir: ./output/rec_mobile_pp-OCRv2
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
Expand All @@ -19,7 +19,7 @@ Global:
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
save_res_path: ./output/rec/predicts_mobile_pp-OCRv2.txt


Optimizer:
Expand All @@ -35,79 +35,32 @@ Optimizer:
name: L2
factor: 2.0e-05


Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002

model_type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002

Loss:
name: CombinedLoss
loss_config_list:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
name: CTCLoss

PostProcess:
name: DistillationCTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out
name: CTCLabelDecode

Metric:
name: DistillationMetric
base_metric_name: RecMetric
name: RecMetric
main_indicator: acc
key: "Student"

Train:
dataset:
Expand All @@ -132,7 +85,6 @@ Train:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_sections: 1
num_workers: 8
Eval:
dataset:
Expand Down
160 changes: 160 additions & 0 deletions configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
Global:
debug: false
use_gpu: true
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_pp-OCRv2_distillation
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_pp-OCRv2_distillation.txt


Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs : [700, 800]
values : [0.001, 0.0001]
warmup_epoch: 5
regularizer:
name: L2
factor: 2.0e-05

Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002


Loss:
name: CombinedLoss
loss_config_list:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
use_log: true
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out

PostProcess:
name: DistillationCTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out

Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: "Student"

Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_sections: 1
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 8
35 changes: 34 additions & 1 deletion doc/doc_ch/knowledge_distillation.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要

### 2.1 识别配置文件解析

配置文件在[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)
配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)

#### 2.1.1 模型结构

Expand Down Expand Up @@ -246,6 +246,39 @@ Metric:
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。


#### 2.1.5 蒸馏模型微调

对蒸馏得到的识别蒸馏进行微调有2种方式。

(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。

(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。

* 首先下载预训练模型并解压。
```shell
# 下面预训练模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
tar -xf ch_PP-OCRv2_rec_train.tar
```

* 然后使用python,对其中的学生模型参数进行提取

```python
import paddle
# 加载预训练模型
all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 学生模型的权重提取
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# 查看学生模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
```

转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。

### 2.2 检测配置文件解析

* coming soon!
Loading

0 comments on commit cc24646

Please sign in to comment.