From 27b6346c955e546422ab17b4d545b5eefdbfe2b3 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Wed, 29 Sep 2021 10:30:09 +0800 Subject: [PATCH] add center loss cod and cfg (#4165) * add center loss cod and cfg * fix name --- .../ch_PP-OCRv2_rec_enhanced_ctc_loss.yml | 126 ++++++++++++++++++ ppocr/data/imaug/label_ops.py | 5 + ppocr/losses/__init__.py | 1 - ppocr/losses/ace_loss.py | 50 +++++++ ppocr/losses/center_loss.py | 89 +++++++++++++ ppocr/losses/combined_loss.py | 4 + ppocr/losses/rec_ctc_loss.py | 10 +- ppocr/modeling/heads/rec_ctc_head.py | 17 ++- ppocr/postprocess/rec_postprocess.py | 2 + 9 files changed, 298 insertions(+), 6 deletions(-) create mode 100644 configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml create mode 100644 ppocr/losses/ace_loss.py create mode 100644 ppocr/losses/center_loss.py diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml new file mode 100644 index 0000000000..8b568637a1 --- /dev/null +++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml @@ -0,0 +1,126 @@ +Global: + debug: false + use_gpu: true + epoch_num: 800 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_mobile_pp-OCRv2_enhanced_ctc_loss + save_epoch_step: 3 + eval_batch_step: [0, 2000] + cal_metric_during_train: true + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: false + infer_img: doc/imgs_words/ch/word_1.jpg + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + character_type: ch + max_text_length: 25 + infer_mode: false + use_space_char: true + distributed: true + save_res_path: ./output/rec/predicts_mobile_pp-OCRv2_enhanced_ctc_loss.txt + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Piecewise + decay_epochs : [700, 800] + values : [0.001, 0.0001] + warmup_epoch: 5 + regularizer: + name: L2 + factor: 2.0e-05 + + +Architecture: + model_type: rec + algorithm: CRNN + Transform: + Backbone: + name: MobileNetV1Enhance + scale: 0.5 + Neck: + name: SequenceEncoder + encoder_type: rnn + hidden_size: 64 + Head: + name: CTCHead + mid_channels: 96 + fc_decay: 0.00002 + return_feats: true + +Loss: + name: CombinedLoss + loss_config_list: + - CTCLoss: + use_focal_loss: false + weight: 1.0 + - CenterLoss: + weight: 0.05 + num_classes: 6625 + feat_dim: 96 + init_center: false + center_file_path: "./train_center.pkl" + # you can also try to add ace loss on your own dataset + # - ACELoss: + # weight: 0.1 + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ + label_file_list: + - ./train_data/train_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - RecAug: + - CTCLabelEncode: + - RecResizeImg: + image_shape: [3, 32, 320] + - KeepKeys: + keep_keys: + - image + - label + - length + - label_ace + loader: + shuffle: true + batch_size_per_card: 128 + drop_last: true + num_workers: 8 +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data + label_file_list: + - ./train_data/val_list.txt + transforms: + - DecodeImage: + img_mode: BGR + channel_first: false + - CTCLabelEncode: + - RecResizeImg: + image_shape: [3, 32, 320] + - KeepKeys: + keep_keys: + - image + - label + - length + loader: + shuffle: false + drop_last: false + batch_size_per_card: 128 + num_workers: 8 diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index f761eaf669..ebf52ec4e1 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -215,6 +215,11 @@ def __call__(self, data): data['length'] = np.array(len(text)) text = text + [0] * (self.max_text_len - len(text)) data['label'] = np.array(text) + + label = [0] * len(self.character) + for x in text: + label[x] += 1 + data['label_ace'] = np.array(label) return data def add_special_char(self, dict_character): diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index a6c2a9f6d1..f3f4cd4933 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -52,7 +52,6 @@ def build_loss(config): 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss' ] - config = copy.deepcopy(config) module_name = config.pop('name') assert module_name in support_dict, Exception('loss only support {}'.format( diff --git a/ppocr/losses/ace_loss.py b/ppocr/losses/ace_loss.py new file mode 100644 index 0000000000..9c868520e5 --- /dev/null +++ b/ppocr/losses/ace_loss.py @@ -0,0 +1,50 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn + + +class ACELoss(nn.Layer): + def __init__(self, **kwargs): + super().__init__() + self.loss_func = nn.CrossEntropyLoss( + weight=None, + ignore_index=0, + reduction='none', + soft_label=True, + axis=-1) + + def __call__(self, predicts, batch): + if isinstance(predicts, (list, tuple)): + predicts = predicts[-1] + B, N = predicts.shape[:2] + div = paddle.to_tensor([N]).astype('float32') + + predicts = nn.functional.softmax(predicts, axis=-1) + aggregation_preds = paddle.sum(predicts, axis=1) + aggregation_preds = paddle.divide(aggregation_preds, div) + + length = batch[2].astype("float32") + batch = batch[3].astype("float32") + batch[:, 0] = paddle.subtract(div, length) + + batch = paddle.divide(batch, div) + + loss = self.loss_func(aggregation_preds, batch) + + return {"loss_ace": loss} diff --git a/ppocr/losses/center_loss.py b/ppocr/losses/center_loss.py new file mode 100644 index 0000000000..72149df19f --- /dev/null +++ b/ppocr/losses/center_loss.py @@ -0,0 +1,89 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import pickle + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class CenterLoss(nn.Layer): + """ + Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. + """ + + def __init__(self, + num_classes=6625, + feat_dim=96, + init_center=False, + center_file_path=None): + super().__init__() + self.num_classes = num_classes + self.feat_dim = feat_dim + self.centers = paddle.randn( + shape=[self.num_classes, self.feat_dim]).astype( + "float64") #random center + + if init_center: + assert os.path.exists( + center_file_path + ), f"center path({center_file_path}) must exist when init_center is set as True." + with open(center_file_path, 'rb') as f: + char_dict = pickle.load(f) + for key in char_dict.keys(): + self.centers[key] = paddle.to_tensor(char_dict[key]) + + def __call__(self, predicts, batch): + assert isinstance(predicts, (list, tuple)) + features, predicts = predicts + + feats_reshape = paddle.reshape( + features, [-1, features.shape[-1]]).astype("float64") + label = paddle.argmax(predicts, axis=2) + label = paddle.reshape(label, [label.shape[0] * label.shape[1]]) + + batch_size = feats_reshape.shape[0] + + #calc feat * feat + dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True) + dist1 = paddle.expand(dist1, [batch_size, self.num_classes]) + + #dist2 of centers + dist2 = paddle.sum(paddle.square(self.centers), axis=1, + keepdim=True) #num_classes + dist2 = paddle.expand(dist2, + [self.num_classes, batch_size]).astype("float64") + dist2 = paddle.transpose(dist2, [1, 0]) + + #first x * x + y * y + distmat = paddle.add(dist1, dist2) + tmp = paddle.matmul(feats_reshape, + paddle.transpose(self.centers, [1, 0])) + distmat = distmat - 2.0 * tmp + + #generate the mask + classes = paddle.arange(self.num_classes).astype("int64") + label = paddle.expand( + paddle.unsqueeze(label, 1), (batch_size, self.num_classes)) + mask = paddle.equal( + paddle.expand(classes, [batch_size, self.num_classes]), + label).astype("float64") #get mask + dist = paddle.multiply(distmat, mask) + loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size + return {'loss_center': loss} diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py index f3bb36cf5a..72f706e37d 100644 --- a/ppocr/losses/combined_loss.py +++ b/ppocr/losses/combined_loss.py @@ -15,6 +15,10 @@ import paddle import paddle.nn as nn +from .rec_ctc_loss import CTCLoss +from .center_loss import CenterLoss +from .ace_loss import ACELoss + from .distillation_loss import DistillationCTCLoss from .distillation_loss import DistillationDMLLoss from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py index 6c0b56ff84..5d09802b46 100755 --- a/ppocr/losses/rec_ctc_loss.py +++ b/ppocr/losses/rec_ctc_loss.py @@ -21,16 +21,24 @@ class CTCLoss(nn.Layer): - def __init__(self, **kwargs): + def __init__(self, use_focal_loss=False, **kwargs): super(CTCLoss, self).__init__() self.loss_func = nn.CTCLoss(blank=0, reduction='none') + self.use_focal_loss = use_focal_loss def forward(self, predicts, batch): + if isinstance(predicts, (list, tuple)): + predicts = predicts[-1] predicts = predicts.transpose((1, 0, 2)) N, B, _ = predicts.shape preds_lengths = paddle.to_tensor([N] * B, dtype='int64') labels = batch[1].astype("int32") label_lengths = batch[2].astype('int64') loss = self.loss_func(predicts, labels, preds_lengths, label_lengths) + if self.use_focal_loss: + weight = paddle.exp(-loss) + weight = paddle.subtract(paddle.to_tensor([1.0]), weight) + weight = paddle.square(weight) * self.focal_loss_alpha + loss = paddle.multiply(loss, weight) loss = loss.mean() # sum return {'loss': loss} diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 9c38d31fa0..35d33d5f56 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -38,6 +38,7 @@ def __init__(self, out_channels, fc_decay=0.0004, mid_channels=None, + return_feats=False, **kwargs): super(CTCHead, self).__init__() if mid_channels is None: @@ -66,14 +67,22 @@ def __init__(self, bias_attr=bias_attr2) self.out_channels = out_channels self.mid_channels = mid_channels + self.return_feats = return_feats def forward(self, x, targets=None): if self.mid_channels is None: predicts = self.fc(x) else: - predicts = self.fc1(x) - predicts = self.fc2(predicts) - + x = self.fc1(x) + predicts = self.fc2(x) + + if self.return_feats: + result = (x, predicts) + else: + result = predicts + if not self.training: predicts = F.softmax(predicts, axis=2) - return predicts + result = predicts + + return result diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 16f7f76596..c06159ca55 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -111,6 +111,8 @@ def __init__(self, character_type, use_space_char) def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, tuple): + preds = preds[-1] if isinstance(preds, paddle.Tensor): preds = preds.numpy() preds_idx = preds.argmax(axis=2)