From cc24646a876773e2b28de3b76a2dafbcefdbc729 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Thu, 9 Sep 2021 23:51:48 +0800 Subject: [PATCH] fix rec distillation (#3995) * fix rec distillation * add dist cfg * fix cfg --- configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml | 88 +++------- .../ch_PP-OCRv2_rec_distillation.yml | 160 ++++++++++++++++++ doc/doc_ch/knowledge_distillation.md | 35 +++- ppocr/losses/basic_loss.py | 29 ++-- ppocr/losses/combined_loss.py | 14 +- ppocr/losses/distillation_loss.py | 14 +- ppocr/utils/save_load.py | 10 +- 7 files changed, 255 insertions(+), 95 deletions(-) create mode 100644 configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml index 27ba4fd70b..38f77f7372 100644 --- a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml +++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml @@ -4,7 +4,7 @@ Global: epoch_num: 800 log_smooth_window: 20 print_batch_step: 10 - save_model_dir: ./output/rec_chinese_lite_distillation_v2.1 + save_model_dir: ./output/rec_mobile_pp-OCRv2 save_epoch_step: 3 eval_batch_step: [0, 2000] cal_metric_during_train: true @@ -19,7 +19,7 @@ Global: infer_mode: false use_space_char: true distributed: true - save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt + save_res_path: ./output/rec/predicts_mobile_pp-OCRv2.txt Optimizer: @@ -35,79 +35,32 @@ Optimizer: name: L2 factor: 2.0e-05 + Architecture: - model_type: &model_type "rec" - name: DistillationModel - algorithm: Distillation - Models: - Teacher: - pretrained: - freeze_params: false - return_all_feats: true - model_type: *model_type - algorithm: CRNN - Transform: - Backbone: - name: MobileNetV1Enhance - scale: 0.5 - Neck: - name: SequenceEncoder - encoder_type: rnn - hidden_size: 64 - Head: - name: CTCHead - mid_channels: 96 - fc_decay: 0.00002 - Student: - pretrained: - freeze_params: false - return_all_feats: true - model_type: *model_type - algorithm: CRNN - Transform: - Backbone: - name: MobileNetV1Enhance - scale: 0.5 - Neck: - name: SequenceEncoder - encoder_type: rnn - hidden_size: 64 - Head: - name: CTCHead - mid_channels: 96 - fc_decay: 0.00002 - + model_type: rec + algorithm: CRNN + Transform: + Backbone: + name: MobileNetV1Enhance + scale: 0.5 + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 64 + Head: + name: CTCHead + mid_channels: 96 + fc_decay: 0.00002 Loss: - name: CombinedLoss - loss_config_list: - - DistillationCTCLoss: - weight: 1.0 - model_name_list: ["Student", "Teacher"] - key: head_out - - DistillationDMLLoss: - weight: 1.0 - act: "softmax" - model_name_pairs: - - ["Student", "Teacher"] - key: head_out - - DistillationDistanceLoss: - weight: 1.0 - mode: "l2" - model_name_pairs: - - ["Student", "Teacher"] - key: backbone_out + name: CTCLoss PostProcess: - name: DistillationCTCLabelDecode - model_name: ["Student", "Teacher"] - key: head_out + name: CTCLabelDecode Metric: - name: DistillationMetric - base_metric_name: RecMetric + name: RecMetric main_indicator: acc - key: "Student" Train: dataset: @@ -132,7 +85,6 @@ Train: shuffle: true batch_size_per_card: 128 drop_last: true - num_sections: 1 num_workers: 8 Eval: dataset: diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml new file mode 100644 index 0000000000..d2308fd574 --- /dev/null +++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml @@ -0,0 +1,160 @@ +Global: + debug: false + use_gpu: true + epoch_num: 800 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_pp-OCRv2_distillation + save_epoch_step: 3 + eval_batch_step: [0, 2000] + cal_metric_during_train: true + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + character_type: ch + max_text_length: 25 + infer_mode: false + use_space_char: true + distributed: true + save_res_path: ./output/rec/predicts_pp-OCRv2_distillation.txt + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs : [700, 800] + values : [0.001, 0.0001] + warmup_epoch: 5 + regularizer: + name: L2 + factor: 2.0e-05 + +Architecture: + model_type: &model_type "rec" + name: DistillationModel + algorithm: Distillation + Models: + Teacher: + pretrained: + freeze_params: false + return_all_feats: true + model_type: *model_type + algorithm: CRNN + Transform: + Backbone: + name: MobileNetV1Enhance + scale: 0.5 + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 64 + Head: + name: CTCHead + mid_channels: 96 + fc_decay: 0.00002 + Student: + pretrained: + freeze_params: false + return_all_feats: true + model_type: *model_type + algorithm: CRNN + Transform: + Backbone: + name: MobileNetV1Enhance + scale: 0.5 + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 64 + Head: + name: CTCHead + mid_channels: 96 + fc_decay: 0.00002 + + +Loss: + name: CombinedLoss + loss_config_list: + - DistillationCTCLoss: + weight: 1.0 + model_name_list: ["Student", "Teacher"] + key: head_out + - DistillationDMLLoss: + weight: 1.0 + act: "softmax" + use_log: true + model_name_pairs: + - ["Student", "Teacher"] + key: head_out + - DistillationDistanceLoss: + weight: 1.0 + mode: "l2" + model_name_pairs: + - ["Student", "Teacher"] + key: backbone_out + +PostProcess: + name: DistillationCTCLabelDecode + model_name: ["Student", "Teacher"] + key: head_out + +Metric: + name: DistillationMetric + base_metric_name: RecMetric + main_indicator: acc + key: "Student" + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + label_file_list: + - ./train_data/train_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecAug: + - CTCLabelEncode: + - RecResizeImg: + image_shape: [3, 32, 320] + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: true + batch_size_per_card: 128 + drop_last: true + num_sections: 1 + num_workers: 8 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - CTCLabelEncode: + - RecResizeImg: + image_shape: [3, 32, 320] + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: false + drop_last: false + batch_size_per_card: 128 + num_workers: 8 diff --git a/doc/doc_ch/knowledge_distillation.md b/doc/doc_ch/knowledge_distillation.md index 5827f48c81..b2772454d9 100644 --- a/doc/doc_ch/knowledge_distillation.md +++ b/doc/doc_ch/knowledge_distillation.md @@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要 ### 2.1 识别配置文件解析 -配置文件在[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)。 +配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)。 #### 2.1.1 模型结构 @@ -246,6 +246,39 @@ Metric: 关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。 +#### 2.1.5 蒸馏模型微调 + +对蒸馏得到的识别蒸馏进行微调有2种方式。 + +(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。 + +(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。 + +* 首先下载预训练模型并解压。 +```shell +# 下面预训练模型并解压 +wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar +tar -xf ch_PP-OCRv2_rec_train.tar +``` + +* 然后使用python,对其中的学生模型参数进行提取 + +```python +import paddle +# 加载预训练模型 +all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams") +# 查看权重参数的keys +print(all_params.keys()) +# 学生模型的权重提取 +s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key} +# 查看学生模型权重参数的keys +print(s_params.keys()) +# 保存 +paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams") +``` + +转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。 + ### 2.2 检测配置文件解析 * coming soon! diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index 8306523ac1..d2ef5e5ac9 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -56,31 +56,34 @@ def forward(self, x, label): class KLJSLoss(object): def __init__(self, mode='kl'): - assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']" + assert mode in ['kl', 'js', 'KL', 'JS' + ], "mode can only be one of ['kl', 'js', 'KL', 'JS']" self.mode = mode def __call__(self, p1, p2, reduction="mean"): - loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5)) + loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) if self.mode.lower() == "js": - loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5)) + loss += paddle.multiply( + p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) loss *= 0.5 if reduction == "mean": - loss = paddle.mean(loss, axis=[1,2]) - elif reduction=="none" or reduction is None: - return loss + loss = paddle.mean(loss, axis=[1, 2]) + elif reduction == "none" or reduction is None: + return loss else: - loss = paddle.sum(loss, axis=[1,2]) + loss = paddle.sum(loss, axis=[1, 2]) + + return loss - return loss class DMLLoss(nn.Layer): """ DMLLoss """ - def __init__(self, act=None): + def __init__(self, act=None, use_log=False): super().__init__() if act is not None: assert act in ["softmax", "sigmoid"] @@ -90,20 +93,24 @@ def __init__(self, act=None): self.act = nn.Sigmoid() else: self.act = None - + + self.use_log = use_log + self.jskl_loss = KLJSLoss(mode="js") def forward(self, out1, out2): if self.act is not None: out1 = self.act(out1) out2 = self.act(out2) - if len(out1.shape) < 2: + if self.use_log: + # for recognition distillation, log is needed for feature map log_out1 = paddle.log(out1) log_out2 = paddle.log(out2) loss = (F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div( log_out2, out1, reduction='batchmean')) / 2.0 else: + # for detection distillation log is not needed loss = self.jskl_loss(out1, out2) return loss diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py index 0d6fe968d0..f3bb36cf5a 100644 --- a/ppocr/losses/combined_loss.py +++ b/ppocr/losses/combined_loss.py @@ -49,11 +49,15 @@ def forward(self, input, batch, **kargs): loss = loss_func(input, batch, **kargs) if isinstance(loss, paddle.Tensor): loss = {"loss_{}_{}".format(str(loss), idx): loss} + weight = self.loss_weight[idx] - for key in loss.keys(): - if key == "loss": - loss_all += loss[key] * weight - else: - loss_dict["{}_{}".format(key, idx)] = loss[key] + + loss = {key: loss[key] * weight for key in loss} + + if "loss" in loss: + loss_all += loss["loss"] + else: + loss_all += paddle.add_n(list(loss.values())) + loss_dict.update(loss) loss_dict["loss"] = loss_all return loss_dict diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 75f0a77315..73d3ae2ad2 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -44,20 +44,22 @@ class DistillationDMLLoss(DMLLoss): def __init__(self, model_name_pairs=[], act=None, + use_log=False, key=None, maps_name=None, name="dml"): - super().__init__(act=act) + super().__init__(act=act, use_log=use_log) assert isinstance(model_name_pairs, list) self.key = key self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) self.name = name self.maps_name = self._check_maps_name(maps_name) - + def _check_model_name_pairs(self, model_name_pairs): if not isinstance(model_name_pairs, list): return [] - elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str): + elif isinstance(model_name_pairs[0], list) and isinstance( + model_name_pairs[0][0], str): return model_name_pairs else: return [model_name_pairs] @@ -112,9 +114,9 @@ def forward(self, predicts, batch): loss_dict["{}_{}_{}_{}_{}".format(key, pair[ 0], pair[1], map_name, idx)] = loss[key] else: - loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c], - idx)] = loss - + loss_dict["{}_{}_{}".format(self.name, self.maps_name[ + _c], idx)] = loss + loss_dict = _sum_loss(loss_dict) return loss_dict diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 3bb022ed98..a7d24dd71a 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer): for k1, k2 in zip(state_dict.keys(), params.keys()): if list(state_dict[k1].shape) == list(params[k2].shape): new_state_dict[k1] = params[k2] - else: - logger.info( - f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" - ) + else: + logger.info( + f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" + ) model.set_state_dict(new_state_dict) logger.info(f"loaded pretrained_model successful from {pm}") return {} + def load_pretrained_params(model, path): if path is None: return False @@ -138,6 +139,7 @@ def load_pretrained_params(model, path): print(f"load pretrain successful from {path}") return model + def save_model(model, optimizer, model_path,