diff --git a/README.md b/README.md index a57365e0a6..0db6fc543b 100755 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力 ## 📖 技术交流合作 - 飞桨AI套件([PaddleX](http://10.136.157.23:8080/paddle/paddleX))提供了飞桨模型训压推一站式全流程高效率开发平台,其使命是助力AI技术快速落地,愿景是使人人成为AI Developer! - - PaddleX 目前覆盖图像分类、目标检测、图像分割、3D、OCR和时序预测等领域方向,已内置了36种基础单模型,例如RP-DETR、PP-YOLOE、PP-HGNet、PP-LCNet、PP-LiteSeg等;集成了12种实用的产业方案,例如PP-OCRv4、PP-ChatOCR、PP-ShiTu、PP-TS、车载路面垃圾检测、野生动物违禁制品识别等。 + - PaddleX 目前覆盖图像分类、目标检测、图像分割、3D、OCR和时序预测等领域方向,已内置了36种基础单模型,例如RT-DETR、PP-YOLOE、PP-HGNet、PP-LCNet、PP-LiteSeg等;集成了12种实用的产业方案,例如PP-OCRv4、PP-ChatOCR、PP-ShiTu、PP-TS、车载路面垃圾检测、野生动物违禁制品识别等。 - PaddleX 提供了“工具箱”和“开发者”两种AI开发模式。工具箱模式可以无代码调优关键超参,开发者模式可以低代码进行单模型训压推和多模型串联推理,同时支持云端和本地端。 - PaddleX 还支持联创开发,利润分成!目前 PaddleX 正在快速迭代,欢迎广大的个人开发者和企业开发者参与进来,共创繁荣的 AI 技术生态! diff --git a/configs/rec/rec_vit_parseq.yml b/configs/rec/rec_vit_parseq.yml new file mode 100644 index 0000000000..8ba99e6438 --- /dev/null +++ b/configs/rec/rec_vit_parseq.yml @@ -0,0 +1,116 @@ +Global: + use_gpu: True + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/rec/parseq + save_epoch_step: 3 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 500] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: ppocr/utils/dict/parseq_dict.txt + character_type: en + max_text_length: 25 + num_heads: 8 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_parseq.txt + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: OneCycle + max_lr: 0.0007 + +Architecture: + model_type: rec + algorithm: ParseQ + in_channels: 3 + Transform: + Backbone: + name: ViTParseQ + img_size: [32, 128] + patch_size: [4, 8] + embed_dim: 384 + depth: 12 + num_heads: 6 + mlp_ratio: 4 + in_channels: 3 + Head: + name: ParseQHead + # Architecture + max_text_length: 25 + embed_dim: 384 + dec_num_heads: 12 + dec_mlp_ratio: 4 + dec_depth: 1 + # Training + perm_num: 6 + perm_forward: true + perm_mirrored: true + dropout: 0.1 + # Decoding mode (test) + decode_ar: true + refine_iters: 1 + +Loss: + name: ParseQLoss + +PostProcess: + name: ParseQLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: LMDBDataSet + data_dir: + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - ParseQRecAug: + aug_type: 0 # or 1 + - ParseQLabelEncode: + - SVTRRecResizeImg: + image_shape: [3, 32, 128] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 192 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: LMDBDataSet + data_dir: + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - ParseQLabelEncode: # Class handling label + - SVTRRecResizeImg: + image_shape: [3, 32, 128] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 384 + num_workers: 4 diff --git a/doc/doc_ch/PP-OCRv4_introduction.md b/doc/doc_ch/PP-OCRv4_introduction.md index cf1ac63f86..b3dbe76e8b 100644 --- a/doc/doc_ch/PP-OCRv4_introduction.md +++ b/doc/doc_ch/PP-OCRv4_introduction.md @@ -81,13 +81,13 @@ PP-OCRv4检测模型对PP-OCRv3中的CML(Collaborative Mutual Learning) 协同 ## 3. 
识别优化 -PP-OCRv3的识别模块是基于文本识别算法[SVTR](https://arxiv.org/abs/2205.00159)优化。SVTR不再采用RNN结构,通过引入Transformers结构更加有效地挖掘文本行图像的上下文信息,从而提升文本识别能力。直接将PP-OCRv2的识别模型,替换成SVTR_Tiny,识别准确率从74.8%提升到80.1%(+5.3%),但是预测速度慢了将近11倍,CPU上预测一条文本行,将近100ms。因此,如下图所示,PP-OCRv3采用如下6个优化策略进行识别模型加速。 +PP-OCRv4识别模型在PP-OCRv3的基础上进一步升级。如下图所示,整体的框架图保持了与PP-OCRv3识别模型相同的pipeline,分别进行了数据、网络结构、训练策略等方面的优化。
-基于上述策略,PP-OCRv4识别模型相比PP-OCRv3,在速度可比的情况下,精度进一步提升4%。 具体消融实验如下所示: +经过如图所示的策略优化,PP-OCRv4识别模型相比PP-OCRv3,在速度可比的情况下,精度进一步提升4%。 具体消融实验如下所示: | ID | 策略 | 模型大小 | 精度 | 预测耗时(CPU openvino)| |-----|-----|--------|----| --- | @@ -103,8 +103,8 @@ PP-OCRv3的识别模块是基于文本识别算法[SVTR](https://arxiv.org/abs/2 **(1)DF:数据挖掘方案** -DF(Data Filter) 是一种简单有效的数据挖掘方案。核心思想是利用已有模型预测训练数据,通过置信度和预测结果等信息,对全量数据进行筛选。具体的:首先使用少量数据快速训练得到一个低精度模型,使用该低精度模型对千万级的数据进行预测,去除置信度大于0.95的样本,该部分被认为是对提升模型精度无效的冗余数据。其次使用PP-OCRv3作为高精度模型,对剩余数据进行预测,去除置信度小于0.15的样本,该部分被认为是难以识别或质量很差的样本。 -使用该策略,千万级别训练数据被精简至百万级,显著提升模型训练效率,模型训练时间从2周减少到5天,同时精度提升至72.7%(+1.2%)。 +DF(Data Filter) 是一种简单有效的数据挖掘方案。核心思想是利用已有模型预测训练数据,通过置信度和预测结果等信息,对全量的训练数据进行筛选。具体的:首先使用少量数据快速训练得到一个低精度模型,使用该低精度模型对千万级的数据进行预测,去除置信度大于0.95的样本,该部分被认为是对提升模型精度无效的冗余样本。其次使用PP-OCRv3作为高精度模型,对剩余数据进行预测,去除置信度小于0.15的样本,该部分被认为是难以识别或质量很差的样本。 +使用该策略,千万级别训练数据被精简至百万级,模型训练时间从2周减少到5天,显著提升了训练效率,同时精度提升至72.7%(+1.2%)。
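To make the DF (Data Filter) procedure above concrete, here is a minimal, hedged Python sketch of the two-stage confidence filtering. `predict(model, sample)` is a hypothetical helper returning the recognized text and its confidence; the 0.95 / 0.15 thresholds follow the description above and are not the exact PaddleOCR implementation.

```python
# Hedged sketch of the DF (Data Filter) mining strategy described above.
# `predict(model, sample)` is a hypothetical helper returning (text, confidence);
# the 0.95 / 0.15 thresholds are the ones quoted in this document.
def data_filter(samples, low_precision_model, high_precision_model, predict):
    kept = []
    for sample in samples:
        # Stage 1: drop samples the quickly-trained low-precision model already
        # recognizes with very high confidence (redundant for further training).
        _, conf = predict(low_precision_model, sample)
        if conf > 0.95:
            continue
        # Stage 2: drop samples the high-precision model (PP-OCRv3) can barely
        # recognize (likely very hard, noisy, or mislabeled).
        _, conf = predict(high_precision_model, sample)
        if conf < 0.15:
            continue
        kept.append(sample)
    return kept
```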
@@ -118,12 +118,12 @@ PP-LCNetV3系列模型是PP-LCNet系列模型的延续,覆盖了更大的精 **(3)Lite-Neck:精简参数的Neck结构** -Lite-Neck整体结构沿用PP-OCRv3版本,在参数上稍作精简,识别模型整体的模型大小可从12M降低到8.5M,而精度不变;在CTCHead中,将Neck输出特征的维度从64提升到120,此时模型大小从8.5M提升到9.6M,精度提升0.5%。 +Lite-Neck整体结构沿用PP-OCRv3版本的结构,在参数上稍作精简,识别模型整体的模型大小可从12M降低到8.5M,而精度不变;在CTCHead中,将Neck输出特征的维度从64提升到120,此时模型大小从8.5M提升到9.6M。 **(4)GTC-NRTR:Attention指导CTC训练策略** -GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合多种文本特征的表达,有效的提升文本识别精度。在PP-OCRv4中使用训练更稳定的Transformer模型NRTR作为指导,相比SAR基于循环神经网络的结构,NRTR基于Transformer实现解码过程泛化能力更强,能有效指导CTC分支学习。解决简单场景下快速过拟合的问题。模型大小不变,识别精度提升至73.21%(+0.5%)。 +GTC(Guided Training of CTC),是PP-OCRv3识别模型的最有效的策略之一,融合多种文本特征的表达,有效的提升文本识别精度。在PP-OCRv4中使用训练更稳定的Transformer模型NRTR作为指导分支,相比V3版本中的SAR基于循环神经网络的结构,NRTR基于Transformer实现解码过程泛化能力更强,能有效指导CTC分支学习,解决简单场景下快速过拟合的问题。使用Lite-Neck和GTC-NRTR两个策略,识别精度提升至73.21%(+0.5%)。
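As a rough illustration of the GTC-NRTR idea above, the sketch below combines a CTC branch with an NRTR guidance branch on shared backbone features. The helper names (`backbone`, `ctc_head`, `nrtr_head`, the loss callables) and the equal loss weights are assumptions for illustration, not the exact PP-OCRv4 code; only the CTC branch would be kept at inference time.

```python
# Hedged sketch of GTC-style guided training: an NRTR (Transformer) guidance
# branch and a CTC branch share the same backbone features during training.
# All names and the 1.0 / 1.0 weights are illustrative assumptions.
def gtc_nrtr_loss(backbone, ctc_head, nrtr_head, ctc_loss, nrtr_loss, batch):
    images, ctc_labels, attn_labels = batch
    feats = backbone(images)                     # shared visual features
    ctc_logits = ctc_head(feats)                 # branch kept for inference
    nrtr_logits = nrtr_head(feats, attn_labels)  # guidance branch, training only
    # The combined loss lets the attention decoder guide the CTC backbone.
    return 1.0 * ctc_loss(ctc_logits, ctc_labels) + 1.0 * nrtr_loss(nrtr_logits, attn_labels)
```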
@@ -132,7 +132,7 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合 **(5)Multi-Scale:多尺度训练策略** -动态尺度训练策略,是在训练过程中随机resize输入图片的高度,以增大模型的鲁棒性。在训练过程中随机选择(32,48,64)三种高度进行resize,实验证明在测试集上评估精度不掉,在端到端串联推理时,指标可以提升0.5%。 +动态尺度训练策略,是在训练过程中随机resize输入图片的高度,以增强识别模型在端到端串联使用时的鲁棒性。在训练时,每个iter从(32,48,64)三种高度中随机选择一种高度进行resize。实验证明,使用该策略,尽管在识别测试集上准确率没有提升,但在端到端串联评估时,指标提升0.5%。
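A minimal sketch of the Multi-Scale strategy above, assuming OpenCV is available: one target height is drawn from (32, 48, 64) at each training iteration and the batch of text-line crops is resized to it while keeping the aspect ratio. The helper below is illustrative, not the exact PaddleOCR implementation.

```python
import random
import cv2

# Hedged sketch of the Multi-Scale training strategy described above:
# each iteration samples one height from (32, 48, 64) and resizes the batch.
def multi_scale_resize(batch_imgs, heights=(32, 48, 64)):
    target_h = random.choice(heights)  # one height per iteration
    resized = []
    for img in batch_imgs:
        h, w = img.shape[:2]
        new_w = max(1, int(round(w * target_h / h)))  # keep aspect ratio
        resized.append(cv2.resize(img, (new_w, target_h)))
    return resized
```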
@@ -143,9 +143,9 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合 识别模型的蒸馏包含两个部分,NRTRhead蒸馏和CTCHead蒸馏; -对于NRTR head,使用了DKD loss蒸馏,使学生模型NRTR head输出的logits与教师NRTR head接近。最终NRTR head的loss是学生与教师间的DKD loss和与ground truth的cross entropy loss的加权和,用于监督学生模型的backbone训练。通过实验,我们发现加入DKD loss后,计算与ground truth的cross entropy loss时去除label smoothing可以进一步提高精度,因此我们在这里使用的是不带label smoothing的cross entropy loss。 +对于NRTR head,使用了DKD loss蒸馏,拉近学生模型和教师模型的NRTR head logits。最终NRTR head的loss是学生与教师间的DKD loss和与ground truth的cross entropy loss的加权和,用于监督学生模型的backbone训练。通过实验,我们发现加入DKD loss后,计算与ground truth的cross entropy loss时去除label smoothing可以进一步提高精度,因此我们在这里使用的是不带label smoothing的cross entropy loss。 -对于CTCHead,由于CTC的输出中存在Blank位,即使教师模型和学生模型的预测结果一样,二者的输出的logits分布也会存在差异,影响教师模型向学生模型的知识传递。PP-OCRv4识别模型蒸馏策略中,将CTC输出logits沿着文本长度维度计算均值,将多字符识别问题转换为多字符分类问题,用于监督CTC Head的训练。使用该策略融合NRTRhead DKD蒸馏策略,指标从0.7377提升到0.7545。 +对于CTCHead,由于CTC的输出中存在Blank位,即使教师模型和学生模型的预测结果一样,二者的输出的logits分布也会存在差异,影响教师模型向学生模型的知识传递。PP-OCRv4识别模型蒸馏策略中,将CTC输出logits沿着文本长度维度计算均值,将多字符识别问题转换为多字符分类问题,用于监督CTC Head的训练。使用该策略融合NRTRhead DKD蒸馏策略,指标从74.72%提升到75.45%。 @@ -169,11 +169,11 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合 | PP-OCRv3_en | 64.04% | | PP-OCRv4_en | 70.1% | -同时,也对已支持的80余种语言识别模型进行了升级更新,在有评估集的四种语系识别准确率平均提升5%以上,如下表所示: +同时,对已支持的80余种语言识别模型进行了升级更新,在有评估集的四种语系识别准确率平均提升8%以上,如下表所示: | Model | 拉丁语系 | 阿拉伯语系 | 日语 | 韩语 | |-----|-----|--------|----| --- | | PP-OCR_mul | 69.60% | 40.50% | 38.50% | 55.40% | -| PP-OCRv3_mul | 75.20%| 45.37% | 45.80% | 60.10% | +| PP-OCRv3_mul | 71.57%| 72.90% | 45.85% | 77.23% | | PP-OCRv4_mul | 80.00%| 75.48% | 56.50% | 83.25% | diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index ed556ed9c9..a96bfeb154 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -86,6 +86,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 - [x] [SPIN](./algorithm_rec_spin.md) - [x] [RobustScanner](./algorithm_rec_robustscanner.md) - [x] [RFL](./algorithm_rec_rfl.md) +- [x] [ParseQ](./algorithm_rec_parseq.md) 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -110,6 +111,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 |SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) | |RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| |RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) | +|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) | diff --git a/doc/doc_ch/algorithm_rec_parseq.md b/doc/doc_ch/algorithm_rec_parseq.md new file mode 100644 index 0000000000..7853a9df8d --- /dev/null +++ b/doc/doc_ch/algorithm_rec_parseq.md @@ -0,0 +1,124 @@ +# ParseQ + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 
算法简介 + +论文信息: +> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966) +> Darwin Bautista, Rowel Atienza +> ECCV, 2021 + +原论文分别使用真实文本识别数据集(Real)和合成文本识别数据集(Synth)进行训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估。 +其中: +- 真实文本识别数据集(Real)包含COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR, OpenVINO数据集 +- 合成文本识别数据集(Synth)包含MJSynth和SynthText数据集 + +在不同数据集上训练的算法的复现效果如下: + +|数据集|模型|骨干网络|配置文件|Acc|下载链接| +| --- | --- | --- | --- | --- | --- | +|Synth|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)| +|Real|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)| +||||||| + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 模型训练、评估、预测 + +请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。 + +训练 + +具体地,在完成数据准备后,便可以启动训练,训练命令如下: + +``` +#单卡训练(训练周期长,不建议) +python3 tools/train.py -c configs/rec/rec_vit_parseq.yml + +#多卡训练,通过--gpus参数指定卡号 +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml +``` + +评估 + +``` +# GPU 评估, Global.pretrained_model 为待测权重 +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +预测: + +``` +# 预测使用的配置文件必须与训练一致 +python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + + +## 4. 推理部署 + + +### 4.1 Python推理 +首先将ParseQ文本识别训练过程中保存的模型,转换成inference model。( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz) ),可以使用如下命令进行转换: + +``` +python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq +``` + +ParseQ文本识别模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False +``` + + +### 4.2 C++推理 + +由于C++预处理后处理还未支持ParseQ,所以暂未支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. 
FAQ + + +## 引用 + +```bibtex +@InProceedings{bautista2022parseq, + title={Scene Text Recognition with Permuted Autoregressive Sequence Models}, + author={Bautista, Darwin and Atienza, Rowel}, + booktitle={European Conference on Computer Vision}, + pages={178--196}, + month={10}, + year={2022}, + publisher={Springer Nature Switzerland}, + address={Cham}, + doi={10.1007/978-3-031-19815-1_11}, + url={https://doi.org/10.1007/978-3-031-19815-1_11} +} +``` diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 2e25746dc0..3527e99e71 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -41,8 +41,8 @@ On the ICDAR2015 dataset, the text detection result is as follows: |DB|ResNet50_vd|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| |DB|MobileNetV3|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| -|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[trianed model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)| -|PSE|MobileNetV3|82.20%|70.48%|75.89%|[trianed model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| +|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)| +|PSE|MobileNetV3|82.20%|70.48%|75.89%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)| |DB++|ResNet50|90.89%|82.66%|86.58%|[pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)| On Total-Text dataset, the text detection result is as follows: @@ -83,6 +83,7 @@ Supported text recognition algorithms (Click the link to get the tutorial): - [x] [SPIN](./algorithm_rec_spin_en.md) - [x] [RobustScanner](./algorithm_rec_robustscanner_en.md) - [x] [RFL](./algorithm_rec_rfl_en.md) +- [x] [ParseQ](./algorithm_rec_parseq_en.md) Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -107,6 +108,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) | |RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| |RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) | +|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) | + + diff --git a/doc/doc_en/algorithm_rec_parseq_en.md b/doc/doc_en/algorithm_rec_parseq_en.md new file mode 100644 index 0000000000..a2f8948e5b --- /dev/null +++ b/doc/doc_en/algorithm_rec_parseq_en.md @@ -0,0 +1,123 @@ +# ParseQ + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. 
Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Paper: +> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966) +> Darwin Bautista, Rowel Atienza +> ECCV, 2022 + +Using real datasets (Real) and synthetic datasets (Synth) for training respectively, and evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE datasets. +- The real datasets include COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR and OpenVINO datasets. +- The synthetic datasets include MJSynth and SynthText datasets. + +The algorithm reproduction effect is as follows: + +|Training Dataset|Model|Backbone|config|Acc|Download link| +| --- | --- | --- | --- | --- | --- | +|Synth|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[train model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)| +|Real|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[train model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)| +||||||| + +## 2. Environment +Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. + + + +## 3. Model Training / Evaluation / Prediction + +Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. + +Training: + +Specifically, after the data preparation is completed, the training can be started. The training command is as follows: + +``` +#Single GPU training (long training period, not recommended) +python3 tools/train.py -c configs/rec/rec_vit_parseq.yml + +#Multi GPU training, specify the gpu number through the --gpus parameter +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml +``` + +Evaluation: + +``` +# GPU evaluation +python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy +``` + +Prediction: + +``` +# The configuration file used for prediction must match the training +python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + + +## 4. Inference and Deployment + + +### 4.1 Python Inference +First, the model saved during the ParseQ text recognition training process is converted into an inference model. 
( [Model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz) ), you can use the following command to convert: + +``` +python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq +``` + +For SAR text recognition model inference, the following commands can be executed: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False +``` + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. FAQ + + +## Citation + +```bibtex +@InProceedings{bautista2022parseq, + title={Scene Text Recognition with Permuted Autoregressive Sequence Models}, + author={Bautista, Darwin and Atienza, Rowel}, + booktitle={European Conference on Computer Vision}, + pages={178--196}, + month={10}, + year={2022}, + publisher={Springer Nature Switzerland}, + address={Cham}, + doi={10.1007/978-3-031-19815-1_11}, + url={https://doi.org/10.1007/978-3-031-19815-1_11} +} +``` diff --git a/paddleocr.py b/paddleocr.py index 95134b5c16..dc92cbf6b7 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -513,7 +513,7 @@ def get_model_config(type, version, model_type, lang): def img_decode(content: bytes): np_arr = np.frombuffer(content, dtype=np.uint8) - return cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED) def check_img(img): diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 121582b490..1eb611f6c0 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -27,7 +27,7 @@ from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \ - RFLRecResizeImg, SVTRRecAug + RFLRecResizeImg, SVTRRecAug, ParseQRecAug from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/abinet_aug.py b/ppocr/data/imaug/abinet_aug.py index bcbdadb1ba..9e1b6a6ce9 100644 --- a/ppocr/data/imaug/abinet_aug.py +++ b/ppocr/data/imaug/abinet_aug.py @@ -316,6 +316,35 @@ def __call__(self, img): img = np.clip(img + noise, 0, 255).astype(np.uint8) return img +class CVPossionNoise(object): + def __init__(self, lam=20): + self.lam = lam + if isinstance(lam, numbers.Number): + self.lam = max(int(sample_asym(lam)), 1) + elif isinstance(lam, (tuple, list)) and len(lam) == 2: + self.lam = int(sample_uniform(lam[0], lam[1])) + else: + raise Exception('lam must be number or list with length 2') + + def __call__(self, img): + noise = np.random.poisson(lam=self.lam, size=img.shape) + img = np.clip(img + noise, 0, 255).astype(np.uint8) + return img + +class CVGaussionBlur(object): + def __init__(self, radius): + self.radius = radius + if isinstance(radius, numbers.Number): + self.radius = max(int(sample_asym(radius)), 1) + elif isinstance(radius, (tuple, list)) and len(radius) == 2: + self.radius = int(sample_uniform(radius[0], radius[1])) + else: + raise Exception('radius must be number or list with length 2') + + def __call__(self, img): + fil = 
cv2.getGaussianKernel(ksize=self.radius, sigma=1, ktype=cv2.CV_32F) + img = cv2.sepFilter2D(img, -1, fil, fil) + return img class CVMotionBlur(object): def __init__(self, degrees=12, angle=90): @@ -427,6 +456,29 @@ def __call__(self, img): else: return img +class ParseQDeterioration(object): + def __init__(self, var, degrees, lam, radius, factor, p=0.5): + self.p = p + transforms = [] + if var is not None: + transforms.append(CVGaussianNoise(var=var)) + if degrees is not None: + transforms.append(CVMotionBlur(degrees=degrees)) + if lam is not None: + transforms.append(CVPossionNoise(lam=lam)) + if radius is not None: + transforms.append(CVGaussionBlur(radius=radius)) + if factor is not None: + transforms.append(CVRescale(factor=factor)) + self.transforms = transforms + + def __call__(self, img): + if random.random() < self.p: + random.shuffle(self.transforms) + transforms = Compose(self.transforms) + return transforms(img) + else: + return img class SVTRGeometry(object): def __init__(self, diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 148b093687..7be54aecaa 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1305,6 +1305,37 @@ def add_special_char(self, dict_character): dict_character = ['blank', '', '', ''] + dict_character return dict_character +class ParseQLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + + super(ParseQLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len - 2: + return None + data['length'] = np.array(len(text)) + text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]] + text = text + [self.dict[self.PAD]] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + + def add_special_char(self, dict_character): + dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD] + return dict_character class ViTSTRLabelEncode(BaseRecLabelEncode): """ Convert between text-label and text-index """ diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 9780082f1c..264579c038 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -20,7 +20,7 @@ from PIL import Image import PIL from .text_image_aug import tia_perspective, tia_stretch, tia_distort -from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration +from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration, ParseQDeterioration from paddle.vision.transforms import Compose @@ -204,6 +204,36 @@ def __call__(self, data): data['image'] = img return data +class ParseQRecAug(object): + def __init__(self, + aug_type=0, + geometry_p=0.5, + deterioration_p=0.25, + colorjitter_p=0.25, + **kwargs): + self.transforms = Compose([ + SVTRGeometry( + aug_type=aug_type, + degrees=45, + translate=(0.0, 0.0), + scale=(0.5, 2.), + shear=(45, 15), + distortion=0.5, + p=geometry_p), ParseQDeterioration( + var=20, degrees=6, lam=20, radius=2.0, factor=4, p=deterioration_p), + CVColorJitter( + brightness=0.5, + contrast=0.5, + saturation=0.5, + hue=0.1, + p=colorjitter_p) + ]) + + def __call__(self, data): + img = data['image'] + img = 
self.transforms(img) + data['image'] = img + return data class ClsResizeImg(object): def __init__(self, image_shape, **kwargs): diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 9e6a45478e..3a766f7a7c 100644 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -43,6 +43,7 @@ from .rec_can_loss import CANLoss from .rec_satrn_loss import SATRNLoss from .rec_nrtr_loss import NRTRLoss +from .rec_parseq_loss import ParseQLoss # cls loss from .cls_loss import ClsLoss @@ -76,7 +77,7 @@ def build_loss(config): 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss', - 'SATRNLoss', 'NRTRLoss' + 'SATRNLoss', 'NRTRLoss', 'ParseQLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_parseq_loss.py b/ppocr/losses/rec_parseq_loss.py new file mode 100644 index 0000000000..c2468b091a --- /dev/null +++ b/ppocr/losses/rec_parseq_loss.py @@ -0,0 +1,50 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class ParseQLoss(nn.Layer): + def __init__(self, **kwargs): + super(ParseQLoss, self).__init__() + + def forward(self, predicts, targets): + label = targets[1] # label + label_len = targets[2] + max_step = paddle.max(label_len).cpu().numpy()[0] + 2 + tgt = label[:, :max_step] + + logits_list = predicts['logits_list'] + pad_id = predicts['pad_id'] + eos_id = predicts['eos_id'] + + tgt_out = tgt[:, 1:] + loss = 0 + loss_numel = 0 + n = (tgt_out != pad_id).sum().item() + + for i, logits in enumerate(logits_list): + loss += n * paddle.nn.functional.cross_entropy(input=logits, label=tgt_out.flatten(), ignore_index=pad_id) + loss_numel += n + if i == 1: + tgt_out = paddle.where(condition=tgt_out == eos_id, x=pad_id, y=tgt_out) + n = (tgt_out != pad_id).sum().item() + loss /= loss_numel + + return {'loss': loss} diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 873e8f6de1..60b9daf98a 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -50,11 +50,12 @@ def build_backbone(config, model_type): from .rec_shallow_cnn import ShallowCNN from .rec_lcnetv3 import PPLCNetV3 from .rec_hgnet import PPHGNet_small + from .rec_vit_parseq import ViTParseQ support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet', 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL', - 'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small' + 'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ' ] elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_vit_parseq.py 
b/ppocr/modeling/backbones/rec_vit_parseq.py new file mode 100644 index 0000000000..403d122cb7 --- /dev/null +++ b/ppocr/modeling/backbones/rec_vit_parseq.py @@ -0,0 +1,304 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code is refer from: +https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.5/ppcls/arch/backbone/model_zoo/vision_transformer.py +""" + +from collections.abc import Callable + +import numpy as np +import paddle +import paddle.nn as nn +from paddle.nn.initializer import TruncatedNormal, Constant, Normal + + +trunc_normal_ = TruncatedNormal(std=.02) +normal_ = Normal +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + + +def to_2tuple(x): + return tuple([x] * 2) + + +def drop_path(x, drop_prob=0., training=False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if drop_prob == 0. or not training: + return x + keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype) + shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype) + random_tensor = paddle.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class DropPath(nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Identity(nn.Layer): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +class Mlp(nn.Layer): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Layer): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + # B= paddle.shape(x)[0] + N, C = x.shape[1:] + qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // + self.num_heads)).transpose((2, 0, 3, 1, 4)) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale + attn = nn.functional.softmax(attn, axis=-1) + attn = self.attn_drop(attn) + + x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C)) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Layer): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer='nn.LayerNorm', + epsilon=1e-5): + super().__init__() + if isinstance(norm_layer, str): + self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm1 = norm_layer(dim) + else: + raise TypeError( + "The norm_layer must be str or paddle.nn.layer.Layer class") + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else Identity() + if isinstance(norm_layer, str): + self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) + elif isinstance(norm_layer, Callable): + self.norm2 = norm_layer(dim) + else: + raise TypeError( + "The norm_layer must be str or paddle.nn.layer.Layer class") + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Layer): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + if isinstance(patch_size, int): + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2D( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + x = self.proj(x).flatten(2).transpose((0, 2, 1)) + return x + + +class VisionTransformer(nn.Layer): + """ Vision Transformer with support for patch input + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + class_num=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer='nn.LayerNorm', + epsilon=1e-5, + **kwargs): + super().__init__() + self.class_num = class_num + + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_channels, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = self.create_parameter(shape=(1, num_patches, embed_dim), default_initializer=zeros_) + self.add_parameter("pos_embed", self.pos_embed) + self.cls_token = self.create_parameter( + shape=(1, 1, embed_dim), default_initializer=zeros_) + self.add_parameter("cls_token", self.cls_token) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = np.linspace(0, drop_path_rate, depth) + + self.blocks = nn.LayerList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + epsilon=epsilon) for i in range(depth) + ]) + + self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon) + + # Classifier head + self.head = nn.Linear(embed_dim, + class_num) if class_num > 0 else Identity() + + trunc_normal_(self.pos_embed) + self.out_channels = embed_dim + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + B = paddle.shape(x)[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +class 
ViTParseQ(VisionTransformer): + def __init__(self, img_size=[224, 224], patch_size=[16, 16], in_channels=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0): + super().__init__(img_size, patch_size, in_channels, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, class_num=0) + + def forward(self, x): + return self.forward_features(x) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 440d9e0293..ade67973c6 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -40,6 +40,7 @@ def build_head(config): from .rec_rfl_head import RFLHead from .rec_can_head import CANHead from .rec_satrn_head import SATRNHead + from .rec_parseq_head import ParseQHead # cls head from .cls_head import ClsHead @@ -56,7 +57,7 @@ def build_head(config): 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead', - 'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal' + 'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal', 'ParseQHead' ] if config['name'] == 'DRRGHead': diff --git a/ppocr/modeling/heads/rec_parseq_head.py b/ppocr/modeling/heads/rec_parseq_head.py new file mode 100644 index 0000000000..c68c0a4de6 --- /dev/null +++ b/ppocr/modeling/heads/rec_parseq_head.py @@ -0,0 +1,342 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py +# reference: https://arxiv.org/abs/2207.06966 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import numpy as np +from .self_attention import WrapEncoderForFeature +from .self_attention import WrapEncoder +from collections import OrderedDict +from typing import Optional +import copy +from itertools import permutations + + +class DecoderLayer(paddle.nn.Layer): + """A Transformer decoder layer supporting two-stream attention (XLNet) + This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-05): + super().__init__() + self.self_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True) # paddle.nn.MultiHeadAttention默认为batch_first模式 + self.cross_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True) + self.linear1 = paddle.nn.Linear(in_features=d_model, out_features=dim_feedforward) + self.dropout = paddle.nn.Dropout(p=dropout) + self.linear2 = paddle.nn.Linear(in_features=dim_feedforward, out_features=d_model) + self.norm1 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps) + self.norm2 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps) + self.norm_q = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps) + self.norm_c = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps) + self.dropout1 = paddle.nn.Dropout(p=dropout) + self.dropout2 = paddle.nn.Dropout(p=dropout) + self.dropout3 = paddle.nn.Dropout(p=dropout) + if activation == 'gelu': + self.activation = paddle.nn.GELU() + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = paddle.nn.functional.gelu + super().__setstate__(state) + + def forward_stream(self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask): + """Forward pass for a single stream (i.e. content or query) + tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. + Both tgt_kv and memory are expected to be LayerNorm'd too. + memory is LayerNorm'd by ViT. 
+ """ + if tgt_key_padding_mask is not None: + tgt_mask1 = (tgt_mask!=float('-inf'))[None,None,:,:] & (tgt_key_padding_mask[:,None,None,:]==False) + tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1) + else: + tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask) + + tgt = tgt + self.dropout1(tgt2) + tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt))))) + tgt = tgt + self.dropout3(tgt2) + return tgt, sa_weights, ca_weights + + def forward(self, query, content, memory, query_mask=None, content_mask=None, content_key_padding_mask=None, update_content=True): + query_norm = self.norm_q(query) + content_norm = self.norm_c(content) + query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0] + if update_content: + content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, content_key_padding_mask)[0] + return query, content + + +def get_clones(module, N): + return paddle.nn.LayerList([copy.deepcopy(module) for i in range(N)]) + + +class Decoder(paddle.nn.Layer): + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm): + super().__init__() + self.layers = get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, query, content, memory, query_mask: Optional[paddle.Tensor]=None, content_mask: Optional[paddle.Tensor]=None, content_key_padding_mask: Optional[paddle.Tensor]=None): + for i, mod in enumerate(self.layers): + last = i == len(self.layers) - 1 + query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last) + query = self.norm(query) + return query + + +class TokenEmbedding(paddle.nn.Layer): + + def __init__(self, charset_size: int, embed_dim: int): + super().__init__() + self.embedding = paddle.nn.Embedding(num_embeddings=charset_size, embedding_dim=embed_dim) + self.embed_dim = embed_dim + + def forward(self, tokens: paddle.Tensor): + return math.sqrt(self.embed_dim) * self.embedding(tokens.astype(paddle.int64)) + + +def trunc_normal_init(param, **kwargs): + initializer = nn.initializer.TruncatedNormal(**kwargs) + initializer(param, param.block) + + +def constant_init(param, **kwargs): + initializer = nn.initializer.Constant(**kwargs) + initializer(param, param.block) + + +def kaiming_normal_init(param, **kwargs): + initializer = nn.initializer.KaimingNormal(**kwargs) + initializer(param, param.block) + + +class ParseQHead(nn.Layer): + def __init__(self, out_channels, max_text_length, embed_dim, dec_num_heads, dec_mlp_ratio, dec_depth, perm_num, perm_forward, perm_mirrored, decode_ar, refine_iters, dropout, **kwargs): + super().__init__() + + self.bos_id = out_channels - 2 + self.eos_id = 0 + self.pad_id = out_channels - 1 + + self.max_label_length = max_text_length + self.decode_ar = decode_ar + self.refine_iters = refine_iters + decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=paddle.nn.LayerNorm(normalized_shape=embed_dim)) + self.rng = np.random.default_rng() + self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num + self.perm_forward = perm_forward + self.perm_mirrored = perm_mirrored + self.head = paddle.nn.Linear(in_features=embed_dim, out_features=out_channels - 
2) + self.text_embed = TokenEmbedding(out_channels, embed_dim) + self.pos_queries = paddle.create_parameter(shape=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).shape, dtype=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).numpy().dtype, default_initializer=paddle.nn.initializer.Assign(paddle.empty(shape=[1, max_text_length + 1, embed_dim]))) + self.pos_queries.stop_gradient = not True + self.dropout = paddle.nn.Dropout(p=dropout) + self._device = self.parameters()[0].place + trunc_normal_init(self.pos_queries, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, paddle.nn.Linear): + trunc_normal_init(m.weight, std=0.02) + if m.bias is not None: + constant_init(m.bias, value=0.0) + elif isinstance(m, paddle.nn.Embedding): + trunc_normal_init(m.weight, std=0.02) + if m._padding_idx is not None: + m.weight.data[m._padding_idx].zero_() + elif isinstance(m, paddle.nn.Conv2D): + kaiming_normal_init(m.weight, fan_in=None, nonlinearity='relu') + if m.bias is not None: + constant_init(m.bias, value=0.0) + elif isinstance(m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)): + constant_init(m.weight, value=1.0) + constant_init(m.bias, value=0.0) + + def no_weight_decay(self): + param_names = {'text_embed.embedding.weight', 'pos_queries'} + enc_param_names = {('encoder.' + n) for n in self.encoder. + no_weight_decay()} + return param_names.union(enc_param_names) + + def encode(self, img): + return self.encoder(img) + + def decode(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, tgt_query=None, tgt_query_mask=None): + N, L = tgt.shape + null_ctx = self.text_embed(tgt[:, :1]) + if L != 1: + tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:]) + tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1)) + else: + tgt_emb = self.dropout(null_ctx) + if tgt_query is None: + tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1]) + tgt_query = self.dropout(tgt_query) + return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask) + + def forward_test(self, memory, max_length=None): + testing = max_length is None + max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length) + bs = memory.shape[0] + num_steps = max_length + 1 + + pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1]) + tgt_mask = query_mask = paddle.triu(x=paddle.full(shape=(num_steps, num_steps), fill_value=float('-inf')), diagonal=1) + if self.decode_ar: + tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype('int64') + tgt_in[:, (0)] = self.bos_id + + logits = [] + for i in range(paddle.to_tensor(num_steps)): + j = i + 1 + tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], tgt_query_mask=query_mask[i:j, :j]) + p_i = self.head(tgt_out) + logits.append(p_i) + if j < num_steps: + tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1) + if testing and (tgt_in == self.eos_id).astype('bool').any(axis=-1).astype('bool').all(): + break + logits = paddle.concat(x=logits, axis=1) + else: + tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64') + tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries) + logits = self.head(tgt_out) + if self.refine_iters: + temp = paddle.triu(x=paddle.ones(shape=[num_steps,num_steps], dtype='bool'), diagonal=2) + posi = np.where(temp.cpu().numpy()==True) + query_mask[posi] = 0 + bos = paddle.full(shape=(bs, 1), 
fill_value=self.bos_id).astype('int64') + for i in range(self.refine_iters): + tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1) + tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype='int32') + tgt_padding_mask = tgt_padding_mask.cpu() + tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0 + tgt_padding_mask = tgt_padding_mask.cuda().astype(dtype='float32')==1.0 + tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]]) + logits = self.head(tgt_out) + + final_output = {"predict":logits} + + return final_output + + def gen_tgt_perms(self, tgt): + """Generate shared permutations for the whole batch. + This works because the same attention mask can be used for the shorter sequences + because of the padding mask. + """ + max_num_chars = tgt.shape[1] - 2 + if max_num_chars == 1: + return paddle.arange(end=3).unsqueeze(axis=0) + perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else [] + max_perms = math.factorial(max_num_chars) + if self.perm_mirrored: + max_perms //= 2 + num_gen_perms = min(self.max_gen_perms, max_perms) + if max_num_chars < 5: + if max_num_chars == 4 and self.perm_mirrored: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool = paddle.to_tensor(data=list(permutations(range(max_num_chars), max_num_chars)), place=self._device)[selector] + if self.perm_forward: + perm_pool = perm_pool[1:] + perms = paddle.stack(x=perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), size=num_gen_perms - + len(perms), replace=False) + perms = paddle.concat(x=[perms, perm_pool[i]]) + else: + perms.extend([paddle.randperm(n=max_num_chars) for _ in range(num_gen_perms - len(perms))]) + perms = paddle.stack(x=perms) + if self.perm_mirrored: + comp = perms.flip(axis=-1) + x = paddle.stack(x=[perms, comp]) + perm_2 = list(range(x.ndim)) + perm_2[0] = 1 + perm_2[1] = 0 + perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars)) + bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype) + eos_idx = paddle.full(shape=(len(perms), 1), fill_value= + max_num_chars + 1, dtype=perms.dtype) + perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1) + if len(perms) > 1: + perms[(1), 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1) + return perms + + def generate_attn_masks(self, perm): + """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens) + :param perm: the permutation sequence. 
i = 0 is always the BOS + :return: lookahead attention masks + """ + sz = perm.shape[0] + mask = paddle.zeros(shape=(sz, sz)) + for i in range(sz): + query_idx = perm[i].cpu().numpy().tolist() + masked_keys = perm[i + 1:].cpu().numpy().tolist() + if len(masked_keys) == 0: + break + mask[query_idx, masked_keys] = float('-inf') + content_mask = mask[:-1, :-1].clone() + mask[paddle.eye(num_rows=sz).astype('bool')] = float('-inf') + query_mask = mask[1:, :-1] + return content_mask, query_mask + + def forward_train(self, memory, tgt): + tgt_perms = self.gen_tgt_perms(tgt) + tgt_in = tgt[:, :-1] + tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id) + logits_list = [] + final_out = {} + for i, perm in enumerate(tgt_perms): + tgt_mask, query_mask = self.generate_attn_masks(perm) + out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask) + logits = self.head(out) + if i == 0: + final_out['predict'] = logits + logits = logits.flatten(stop_axis=1) + logits_list.append(logits) + + final_out['logits_list'] = logits_list + final_out['pad_id'] = self.pad_id + final_out['eos_id'] = self.eos_id + + return final_out + + def forward(self, feat, targets=None): + # feat : B, N, C + # targets : labels, labels_len + + if self.training: + label = targets[0] # label + label_len = targets[1] + max_step = paddle.max(label_len).cpu().numpy()[0] + 2 + crop_label = label[:, :max_step] + final_out = self.forward_train(feat, crop_label) + else: + final_out = self.forward_test(feat) + + return final_out diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index c89345e70b..d5093e6c0d 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -28,7 +28,7 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \ SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \ - SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode + SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode, ParseQLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess @@ -53,7 +53,7 @@ def build_post_process(config, global_config=None): 'DistillationSerPostProcess', 'DistillationRePostProcess', 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode', - 'SATRNLabelDecode' + 'SATRNLabelDecode', 'ParseQLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index ce2e9f8b57..bdadb9ed56 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -559,6 +559,95 @@ def get_beg_end_flag_idx(self, beg_or_end): % beg_or_end return idx +class ParseQLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(ParseQLabelDecode, self).__init__(character_dict_path, + use_space_char) + self.max_text_length = kwargs.get('max_text_length', 25) + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, dict): + pred = preds['predict'] + else: + pred = preds + + char_num = len(self.character_str) + 1 # We don't predict nor , with only addition + if 
isinstance(pred, paddle.Tensor): + pred = pred.numpy() + B, L = pred.shape[:2] + pred = np.reshape(pred, [-1, char_num]) + + preds_idx = np.argmax(pred, axis=1) + preds_prob = np.max(pred, axis=1) + + preds_idx = np.reshape(preds_idx, [B, L]) + preds_prob = np.reshape(preds_prob, [B, L]) + + if label is None: + text = self.decode(preds_idx, preds_prob, raw=False) + return text + + text = self.decode(preds_idx, preds_prob, raw=False) + label = self.decode(label, None, False) + + return text, label + + def decode(self, text_index, text_prob=None, raw=False): + """ convert text-index into text-label. """ + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + + index = text_index[batch_idx, :] + prob = None + if text_prob is not None: + prob = text_prob[batch_idx, :] + + if not raw: + index, prob = self._filter(index, prob) + + for idx in range(len(index)): + if index[idx] in ignored_tokens: + continue + char_list.append(self.character[int(index[idx])]) + if text_prob is not None: + conf_list.append(prob[idx]) + else: + conf_list.append(1) + + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + + return result_list + + def add_special_char(self, dict_character): + dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD] + return dict_character + + def _filter(self, ids, probs=None): + ids = ids.tolist() + try: + eos_idx = ids.index(self.dict[self.EOS]) + except ValueError: + eos_idx = len(ids) # Nothing to truncate. + # Truncate after EOS + ids = ids[:eos_idx] + if probs is not None: + probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists) + return ids, probs + + def get_ignored_tokens(self): + return [self.dict[self.BOS], self.dict[self.EOS], self.dict[self.PAD]] class SARLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/ppocr/utils/dict/parseq_dict.txt b/ppocr/utils/dict/parseq_dict.txt new file mode 100644 index 0000000000..1aef43d6b8 --- /dev/null +++ b/ppocr/utils/dict/parseq_dict.txt @@ -0,0 +1,94 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z +! +" +# +$ +% +& +' +( +) +* ++ +, +- +. +/ +: +; +< += +> +? 
+@ +[ +\ +] +^ +_ +` +{ +| +} +~ \ No newline at end of file diff --git a/test_tipc/test_train_inference_python_xpu.sh b/test_tipc/test_train_inference_python_xpu.sh index 7c6dc1e52a..c4add0056c 100644 --- a/test_tipc/test_train_inference_python_xpu.sh +++ b/test_tipc/test_train_inference_python_xpu.sh @@ -29,18 +29,28 @@ fi sed -i 's/use_gpu/use_xpu/g' $FILENAME # disable benchmark as AutoLog required nvidia-smi command sed -i 's/--benchmark:True/--benchmark:False/g' $FILENAME +# python has been updated to version 3.9 for xpu backend +sed -i "s/python3.7/python3.9/g" $FILENAME dataline=`cat $FILENAME` # parser params IFS=$'\n' lines=(${dataline}) +modelname=$(echo ${lines[1]} | cut -d ":" -f2) +if [ $modelname == "rec_r31_sar" ] || [ $modelname == "rec_mtb_nrtr" ]; then + sed -i "s/Global.epoch_num:lite_train_lite_infer=2/Global.epoch_num:lite_train_lite_infer=1/g" $FILENAME + sed -i "s/gpu_list:0|0,1/gpu_list:0,1/g" $FILENAME + sed -i "s/Global.use_xpu:True|True/Global.use_xpu:True/g" $FILENAME +fi + # replace training config file grep -n 'tools/.*yml' $FILENAME | cut -d ":" -f 1 \ | while read line_num ; do train_cmd=$(func_parser_value "${lines[line_num-1]}") trainer_config=$(func_parser_config ${train_cmd}) sed -i 's/use_gpu/use_xpu/g' "$REPO_ROOT_PATH/$trainer_config" + sed -i 's/use_sync_bn: True/use_sync_bn: False/g' "$REPO_ROOT_PATH/$trainer_config" done # change gpu to xpu in execution script diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 9dd33dc7b6..1bb9aa9e40 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -122,6 +122,12 @@ def __init__(self, args): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + elif self.rec_algorithm == "ParseQ": + postprocess_params = { + 'name': 'ParseQLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.postprocess_params = postprocess_params self.predictor, self.input_tensor, self.output_tensors, self.config = \ @@ -439,7 +445,7 @@ def __call__(self, img_list): gsrm_slf_attn_bias1_list.append(norm_img[3]) gsrm_slf_attn_bias2_list.append(norm_img[4]) norm_img_batch.append(norm_img[0]) - elif self.rec_algorithm in ["SVTR", "SATRN"]: + elif self.rec_algorithm in ["SVTR", "SATRN", "ParseQ"]: norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], self.rec_image_shape) norm_img = norm_img[np.newaxis, :] diff --git a/tools/program.py b/tools/program.py index b01c2e43fe..d55961931d 100755 --- a/tools/program.py +++ b/tools/program.py @@ -231,7 +231,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = [ "SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN", - "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet' + "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet', "ParseQ", ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': @@ -664,7 +664,7 @@ def preprocess(is_train=False): 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', - 'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet' + 'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet', 'ParseQ', ] if use_xpu: