@@ -132,7 +132,7 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合
**(5) Multi-Scale: multi-scale training strategy**
-The dynamic-scale training strategy randomly resizes the height of the input images during training to improve the robustness of the model. During training, one of three heights (32, 48, 64) is randomly selected for resizing. Experiments show that accuracy on the test set does not drop, while the end-to-end pipeline metric improves by 0.5%.
+The dynamic-scale training strategy randomly resizes the height of the input images during training to make the recognition model more robust when used in the end-to-end pipeline. At each training iteration, one of three heights (32, 48, 64) is randomly selected for resizing. Experiments show that although this strategy does not improve accuracy on the recognition test set, it improves the end-to-end pipeline metric by 0.5%.
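+
+A minimal sketch of what such a per-iteration resize could look like (the helper name, the height list and the width cap are illustrative assumptions, not the actual PP-OCRv4 implementation):
+
+```
+import random
+import cv2
+
+def multi_scale_resize(img, heights=(32, 48, 64), max_width=320):
+    """Randomly pick a target height for this iteration and resize, keeping the aspect ratio."""
+    h, w = img.shape[:2]
+    target_h = random.choice(heights)                      # height re-sampled every iteration
+    target_w = min(max_width, max(1, int(w * target_h / h)))
+    return cv2.resize(img, (target_w, target_h))
+```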
@@ -143,9 +143,9 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合
Distillation of the recognition model consists of two parts: NRTR head distillation and CTC head distillation.
-For the NRTR head, DKD loss is used for distillation so that the logits output by the student's NRTR head stay close to those of the teacher's NRTR head. The final NRTR head loss is a weighted sum of the student-teacher DKD loss and the cross entropy loss against the ground truth, and it supervises the training of the student backbone. Experiments show that once DKD loss is added, removing label smoothing from the ground-truth cross entropy loss further improves accuracy, so the cross entropy loss without label smoothing is used here.
+For the NRTR head, DKD loss is used for distillation to bring the student model's NRTR head logits close to the teacher's. The final NRTR head loss is a weighted sum of the student-teacher DKD loss and the cross entropy loss against the ground truth, and it supervises the training of the student backbone. Experiments show that once DKD loss is added, removing label smoothing from the ground-truth cross entropy loss further improves accuracy, so the cross entropy loss without label smoothing is used here.
-For the CTC head, CTC outputs contain blank positions, so the logits distributions of the teacher and the student differ even when their predictions are identical, which hinders knowledge transfer from the teacher to the student. In the PP-OCRv4 recognition distillation strategy, the CTC output logits are averaged along the text-length dimension, turning the multi-character recognition problem into a multi-character classification problem, and the result is used to supervise the training of the CTC head. Combining this strategy with the NRTR head DKD distillation strategy improves the metric from 0.7377 to 0.7545.
+For the CTC head, CTC outputs contain blank positions, so the logits distributions of the teacher and the student differ even when their predictions are identical, which hinders knowledge transfer from the teacher to the student. In the PP-OCRv4 recognition distillation strategy, the CTC output logits are averaged along the text-length dimension, turning the multi-character recognition problem into a multi-character classification problem, and the result is used to supervise the training of the CTC head. Combining this strategy with the NRTR head DKD distillation strategy improves the metric from 74.72% to 75.45%.
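+
+As a rough illustration of the averaging idea (the function name, the KL-divergence choice and the temperature are assumptions; this sketch is not the actual PP-OCRv4 loss code):
+
+```
+import paddle
+import paddle.nn.functional as F
+
+def ctc_mean_distill_loss(student_logits, teacher_logits, temperature=1.0):
+    """student_logits / teacher_logits: [B, T, num_classes] CTC head outputs.
+    Averaging over the length axis T removes the dependence on blank alignment."""
+    s = student_logits.mean(axis=1)                        # [B, num_classes]
+    t = teacher_logits.mean(axis=1)
+    log_p_s = F.log_softmax(s / temperature, axis=-1)
+    p_t = F.softmax(t / temperature, axis=-1)
+    # KL divergence between the teacher and student class distributions
+    return F.kl_div(log_p_s, p_t, reduction='mean') * temperature ** 2
+```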
@@ -169,11 +169,11 @@ GTC(Guided Training of CTC),是在PP-OCRv3中使用过的策略,融合
| PP-OCRv3_en | 64.04% |
| PP-OCRv4_en | 70.1% |
-In addition, the recognition models for the 80+ supported languages were also upgraded; for the four language groups that have evaluation sets, recognition accuracy improves by more than 5% on average, as shown in the table below:
+In addition, the recognition models for the 80+ supported languages were upgraded; for the four language groups that have evaluation sets, recognition accuracy improves by more than 8% on average, as shown in the table below:
| Model | Latin | Arabic | Japanese | Korean |
|-----|-----|--------|----| --- |
| PP-OCR_mul | 69.60% | 40.50% | 38.50% | 55.40% |
-| PP-OCRv3_mul | 75.20%| 45.37% | 45.80% | 60.10% |
+| PP-OCRv3_mul | 71.57%| 72.90% | 45.85% | 77.23% |
| PP-OCRv4_mul | 80.00%| 75.48% | 56.50% | 83.25% |
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index ed556ed9c9..a96bfeb154 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -86,6 +86,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
- [x] [RFL](./algorithm_rec_rfl.md)
+- [x] [ParseQ](./algorithm_rec_parseq.md)
Following the [DTRB](https://arxiv.org/abs/1904.01906)[3] text recognition training and evaluation protocol, the models are trained on the MJSynth and SynthText datasets and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets; the results are as follows:
@@ -110,6 +111,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |
+|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) |
diff --git a/doc/doc_ch/algorithm_rec_parseq.md b/doc/doc_ch/algorithm_rec_parseq.md
new file mode 100644
index 0000000000..7853a9df8d
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_parseq.md
@@ -0,0 +1,124 @@
+# ParseQ
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving](#4-3)
+  - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966)
+> Darwin Bautista, Rowel Atienza
+> ECCV, 2022
+
+The original paper trains on a real text recognition dataset (Real) and a synthetic text recognition dataset (Synth) respectively, and evaluates on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets.
+Specifically:
+- the real text recognition data (Real) consists of the COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR and OpenVINO datasets
+- the synthetic text recognition data (Synth) consists of the MJSynth and SynthText datasets
+
+The reproduced results of the models trained on the different datasets are as follows:
+
+|Training Dataset|Model|Backbone|Config|Acc|Download link|
+| --- | --- | --- | --- | --- | --- |
+|Synth|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)|
+|Real|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to the [Text Recognition Tutorial](./recognition.md). PaddleOCR modularizes the code, so training a different recognition model only requires **changing the configuration file**.
+
+Training:
+
+Specifically, once data preparation is complete, training can be started with the following commands:
+
+```
+# Single-GPU training (long training time, not recommended)
+python3 tools/train.py -c configs/rec/rec_vit_parseq.yml
+
+# Multi-GPU training, specify the GPU ids with the --gpus option
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation, Global.pretrained_model points to the weights to be evaluated
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must be consistent with training
+python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, convert the model saved during ParseQ text recognition training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)). The conversion can be done with the following command:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq
+```
+
+For ParseQ text recognition model inference, the following command can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False
+```
+
+
+### 4.2 C++ Inference
+
+Not supported yet, because the C++ pre-processing and post-processing for ParseQ have not been implemented.
+
+
+### 4.3 Serving
+
+Not supported yet
+
+
+### 4.4 More
+
+Not supported yet
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@InProceedings{bautista2022parseq,
+ title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
+ author={Bautista, Darwin and Atienza, Rowel},
+ booktitle={European Conference on Computer Vision},
+ pages={178--196},
+ month={10},
+ year={2022},
+ publisher={Springer Nature Switzerland},
+ address={Cham},
+ doi={10.1007/978-3-031-19815-1_11},
+ url={https://doi.org/10.1007/978-3-031-19815-1_11}
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 2e25746dc0..3527e99e71 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -41,8 +41,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
-|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[trianed model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
-|PSE|MobileNetV3|82.20%|70.48%|75.89%|[trianed model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
+|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
+|PSE|MobileNetV3|82.20%|70.48%|75.89%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
|DB++|ResNet50|90.89%|82.66%|86.58%|[pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
On Total-Text dataset, the text detection result is as follows:
@@ -83,6 +83,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SPIN](./algorithm_rec_spin_en.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
- [x] [RFL](./algorithm_rec_rfl_en.md)
+- [x] [ParseQ](./algorithm_rec_parseq_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -107,6 +108,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)|
|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl_att_train.tar) |
+|ParseQ|VIT| 91.24% | rec_vit_parseq_synth | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz) |
+
diff --git a/doc/doc_en/algorithm_rec_parseq_en.md b/doc/doc_en/algorithm_rec_parseq_en.md
new file mode 100644
index 0000000000..a2f8948e5b
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_parseq_en.md
@@ -0,0 +1,123 @@
+# ParseQ
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966)
+> Darwin Bautista, Rowel Atienza
+> ECCV, 2022
+
+The original paper trains on real datasets (Real) and synthetic datasets (Synth) respectively, and evaluates on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets.
+- The real datasets include the COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR and OpenVINO datasets.
+- The synthetic datasets include the MJSynth and SynthText datasets.
+
+The reproduced results are as follows:
+
+|Training Dataset|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- | --- |
+|Synth|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)|
+|Real|ParseQ|VIT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_vit_parseq.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml
+```
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, the model saved during ParseQ text recognition training is converted into an inference model ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)). The following command can be used for the conversion:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq
+```
+
+For ParseQ text recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@InProceedings{bautista2022parseq,
+ title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
+ author={Bautista, Darwin and Atienza, Rowel},
+ booktitle={European Conference on Computer Vision},
+ pages={178--196},
+ month={10},
+ year={2022},
+ publisher={Springer Nature Switzerland},
+ address={Cham},
+ doi={10.1007/978-3-031-19815-1_11},
+ url={https://doi.org/10.1007/978-3-031-19815-1_11}
+}
+```
diff --git a/paddleocr.py b/paddleocr.py
index 95134b5c16..dc92cbf6b7 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -513,7 +513,7 @@ def get_model_config(type, version, model_type, lang):
def img_decode(content: bytes):
np_arr = np.frombuffer(content, dtype=np.uint8)
- return cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
+ return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
def check_img(img):
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 121582b490..1eb611f6c0 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -27,7 +27,7 @@
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
- RFLRecResizeImg, SVTRRecAug
+ RFLRecResizeImg, SVTRRecAug, ParseQRecAug
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/abinet_aug.py b/ppocr/data/imaug/abinet_aug.py
index bcbdadb1ba..9e1b6a6ce9 100644
--- a/ppocr/data/imaug/abinet_aug.py
+++ b/ppocr/data/imaug/abinet_aug.py
@@ -316,6 +316,35 @@ def __call__(self, img):
img = np.clip(img + noise, 0, 255).astype(np.uint8)
return img
+class CVPossionNoise(object):
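+    # Additive Poisson-distributed noise; the rate lam is sampled once in __init__ from the given value or range.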
+ def __init__(self, lam=20):
+ self.lam = lam
+ if isinstance(lam, numbers.Number):
+ self.lam = max(int(sample_asym(lam)), 1)
+ elif isinstance(lam, (tuple, list)) and len(lam) == 2:
+ self.lam = int(sample_uniform(lam[0], lam[1]))
+ else:
+ raise Exception('lam must be number or list with length 2')
+
+ def __call__(self, img):
+ noise = np.random.poisson(lam=self.lam, size=img.shape)
+ img = np.clip(img + noise, 0, 255).astype(np.uint8)
+ return img
+
+class CVGaussionBlur(object):
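+    # Separable Gaussian blur; the kernel size (radius) is sampled once in __init__, sigma is fixed to 1.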
+ def __init__(self, radius):
+ self.radius = radius
+ if isinstance(radius, numbers.Number):
+ self.radius = max(int(sample_asym(radius)), 1)
+ elif isinstance(radius, (tuple, list)) and len(radius) == 2:
+ self.radius = int(sample_uniform(radius[0], radius[1]))
+ else:
+ raise Exception('radius must be number or list with length 2')
+
+ def __call__(self, img):
+ fil = cv2.getGaussianKernel(ksize=self.radius, sigma=1, ktype=cv2.CV_32F)
+ img = cv2.sepFilter2D(img, -1, fil, fil)
+ return img
class CVMotionBlur(object):
def __init__(self, degrees=12, angle=90):
@@ -427,6 +456,29 @@ def __call__(self, img):
else:
return img
+class ParseQDeterioration(object):
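+    # With probability p, applies the enabled degradations (Gaussian noise, motion blur, Poisson noise, Gaussian blur, rescale) in a random order.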
+ def __init__(self, var, degrees, lam, radius, factor, p=0.5):
+ self.p = p
+ transforms = []
+ if var is not None:
+ transforms.append(CVGaussianNoise(var=var))
+ if degrees is not None:
+ transforms.append(CVMotionBlur(degrees=degrees))
+ if lam is not None:
+ transforms.append(CVPossionNoise(lam=lam))
+ if radius is not None:
+ transforms.append(CVGaussionBlur(radius=radius))
+ if factor is not None:
+ transforms.append(CVRescale(factor=factor))
+ self.transforms = transforms
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ random.shuffle(self.transforms)
+ transforms = Compose(self.transforms)
+ return transforms(img)
+ else:
+ return img
class SVTRGeometry(object):
def __init__(self,
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 148b093687..7be54aecaa 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -1305,6 +1305,37 @@ def add_special_char(self, dict_character):
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character
+class ParseQLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+ BOS = '[B]'
+ EOS = '[E]'
+ PAD = '[P]'
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+
+ super(ParseQLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len - 2:
+ return None
+ data['length'] = np.array(len(text))
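+        # build [B] + text + [E], then right-pad with [P] up to max_text_len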
+ text = [self.dict[self.BOS]] + text + [self.dict[self.EOS]]
+ text = text + [self.dict[self.PAD]] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+ dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
+ return dict_character
class ViTSTRLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 9780082f1c..264579c038 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -20,7 +20,7 @@
from PIL import Image
import PIL
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
-from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration
+from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration, ParseQDeterioration
from paddle.vision.transforms import Compose
@@ -204,6 +204,36 @@ def __call__(self, data):
data['image'] = img
return data
+class ParseQRecAug(object):
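+    # Data augmentation for ParseQ training: SVTR-style geometry, ParseQ deterioration and color jitter, each applied with its own probability.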
+ def __init__(self,
+ aug_type=0,
+ geometry_p=0.5,
+ deterioration_p=0.25,
+ colorjitter_p=0.25,
+ **kwargs):
+ self.transforms = Compose([
+ SVTRGeometry(
+ aug_type=aug_type,
+ degrees=45,
+ translate=(0.0, 0.0),
+ scale=(0.5, 2.),
+ shear=(45, 15),
+ distortion=0.5,
+ p=geometry_p), ParseQDeterioration(
+ var=20, degrees=6, lam=20, radius=2.0, factor=4, p=deterioration_p),
+ CVColorJitter(
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=colorjitter_p)
+ ])
+
+ def __call__(self, data):
+ img = data['image']
+ img = self.transforms(img)
+ data['image'] = img
+ return data
class ClsResizeImg(object):
def __init__(self, image_shape, **kwargs):
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 9e6a45478e..3a766f7a7c 100644
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -43,6 +43,7 @@
from .rec_can_loss import CANLoss
from .rec_satrn_loss import SATRNLoss
from .rec_nrtr_loss import NRTRLoss
+from .rec_parseq_loss import ParseQLoss
# cls loss
from .cls_loss import ClsLoss
@@ -76,7 +77,7 @@ def build_loss(config):
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss',
- 'SATRNLoss', 'NRTRLoss'
+ 'SATRNLoss', 'NRTRLoss', 'ParseQLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_parseq_loss.py b/ppocr/losses/rec_parseq_loss.py
new file mode 100644
index 0000000000..c2468b091a
--- /dev/null
+++ b/ppocr/losses/rec_parseq_loss.py
@@ -0,0 +1,50 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class ParseQLoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super(ParseQLoss, self).__init__()
+
+ def forward(self, predicts, targets):
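+        # Permutation-averaged cross entropy: every entry of logits_list comes from one decoding permutation;
+        # each term is weighted by the number of non-padding target tokens and the sum is normalized by the total count.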
+ label = targets[1] # label
+ label_len = targets[2]
+ max_step = paddle.max(label_len).cpu().numpy()[0] + 2
+ tgt = label[:, :max_step]
+
+ logits_list = predicts['logits_list']
+ pad_id = predicts['pad_id']
+ eos_id = predicts['eos_id']
+
+ tgt_out = tgt[:, 1:]
+ loss = 0
+ loss_numel = 0
+ n = (tgt_out != pad_id).sum().item()
+
+        for i, logits in enumerate(logits_list):
+            loss += n * paddle.nn.functional.cross_entropy(input=logits, label=tgt_out.flatten(), ignore_index=pad_id)
+            loss_numel += n
+            if i == 1:
+                # after the canonical and reverse orderings, ignore the [E] positions for the remaining permutations
+                tgt_out = paddle.where(condition=tgt_out == eos_id, x=pad_id, y=tgt_out)
+                n = (tgt_out != pad_id).sum().item()
+        loss /= loss_numel
+
+ return {'loss': loss}
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 873e8f6de1..60b9daf98a 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -50,11 +50,12 @@ def build_backbone(config, model_type):
from .rec_shallow_cnn import ShallowCNN
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
+ from .rec_vit_parseq import ViTParseQ
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
- 'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small'
+ 'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_vit_parseq.py b/ppocr/modeling/backbones/rec_vit_parseq.py
new file mode 100644
index 0000000000..403d122cb7
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_vit_parseq.py
@@ -0,0 +1,304 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This code is refer from:
+https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.5/ppcls/arch/backbone/model_zoo/vision_transformer.py
+"""
+
+from collections.abc import Callable
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+
+
+trunc_normal_ = TruncatedNormal(std=.02)
+normal_ = Normal
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+def to_2tuple(x):
+ return tuple([x] * 2)
+
+
+def drop_path(x, drop_prob=0., training=False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
+ shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+ random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
+ random_tensor = paddle.floor(random_tensor) # binarize
+ output = x.divide(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Identity(nn.Layer):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, input):
+ return input
+
+
+class Mlp(nn.Layer):
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Layer):
+ def __init__(self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ # B= paddle.shape(x)[0]
+ N, C = x.shape[1:]
+ qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
+ self.num_heads)).transpose((2, 0, 3, 1, 4))
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
+ attn = nn.functional.softmax(attn, axis=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class Block(nn.Layer):
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-5):
+ super().__init__()
+ if isinstance(norm_layer, str):
+ self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
+ elif isinstance(norm_layer, Callable):
+ self.norm1 = norm_layer(dim)
+ else:
+ raise TypeError(
+ "The norm_layer must be str or paddle.nn.layer.Layer class")
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+ if isinstance(norm_layer, str):
+ self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
+ elif isinstance(norm_layer, Callable):
+ self.norm2 = norm_layer(dim)
+ else:
+ raise TypeError(
+ "The norm_layer must be str or paddle.nn.layer.Layer class")
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Layer):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ if isinstance(img_size, int):
+ img_size = to_2tuple(img_size)
+ if isinstance(patch_size, int):
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2D(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+
+ x = self.proj(x).flatten(2).transpose((0, 2, 1))
+ return x
+
+
+class VisionTransformer(nn.Layer):
+ """ Vision Transformer with support for patch input
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ class_num=1000,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer='nn.LayerNorm',
+ epsilon=1e-5,
+ **kwargs):
+ super().__init__()
+ self.class_num = class_num
+
+ self.num_features = self.embed_dim = embed_dim
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_channels,
+ embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.pos_embed = self.create_parameter(shape=(1, num_patches, embed_dim), default_initializer=zeros_)
+ self.add_parameter("pos_embed", self.pos_embed)
+ self.cls_token = self.create_parameter(
+ shape=(1, 1, embed_dim), default_initializer=zeros_)
+ self.add_parameter("cls_token", self.cls_token)
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = np.linspace(0, drop_path_rate, depth)
+
+ self.blocks = nn.LayerList([
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ epsilon=epsilon) for i in range(depth)
+ ])
+
+ self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
+
+ # Classifier head
+ self.head = nn.Linear(embed_dim,
+ class_num) if class_num > 0 else Identity()
+
+ trunc_normal_(self.pos_embed)
+ self.out_channels = embed_dim
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward_features(self, x):
+ B = paddle.shape(x)[0]
+ x = self.patch_embed(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+class ViTParseQ(VisionTransformer):
+ def __init__(self, img_size=[224, 224], patch_size=[16, 16], in_channels=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0):
+ super().__init__(img_size, patch_size, in_channels, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, class_num=0)
+
+ def forward(self, x):
+ return self.forward_features(x)
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 440d9e0293..ade67973c6 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -40,6 +40,7 @@ def build_head(config):
from .rec_rfl_head import RFLHead
from .rec_can_head import CANHead
from .rec_satrn_head import SATRNHead
+ from .rec_parseq_head import ParseQHead
# cls head
from .cls_head import ClsHead
@@ -56,7 +57,7 @@ def build_head(config):
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead',
- 'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal'
+ 'DRRGHead', 'CANHead', 'SATRNHead', 'PFHeadLocal', 'ParseQHead'
]
if config['name'] == 'DRRGHead':
diff --git a/ppocr/modeling/heads/rec_parseq_head.py b/ppocr/modeling/heads/rec_parseq_head.py
new file mode 100644
index 0000000000..c68c0a4de6
--- /dev/null
+++ b/ppocr/modeling/heads/rec_parseq_head.py
@@ -0,0 +1,342 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
+# reference: https://arxiv.org/abs/2207.06966
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+from .self_attention import WrapEncoderForFeature
+from .self_attention import WrapEncoder
+from collections import OrderedDict
+from typing import Optional
+import copy
+from itertools import permutations
+
+
+class DecoderLayer(paddle.nn.Layer):
+ """A Transformer decoder layer supporting two-stream attention (XLNet)
+ This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
+
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-05):
+ super().__init__()
+        self.self_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True) # paddle.nn.MultiHeadAttention defaults to batch_first layout
+ self.cross_attn = paddle.nn.MultiHeadAttention(d_model, nhead, dropout=dropout, need_weights=True)
+ self.linear1 = paddle.nn.Linear(in_features=d_model, out_features=dim_feedforward)
+ self.dropout = paddle.nn.Dropout(p=dropout)
+ self.linear2 = paddle.nn.Linear(in_features=dim_feedforward, out_features=d_model)
+ self.norm1 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
+ self.norm2 = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
+ self.norm_q = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
+ self.norm_c = paddle.nn.LayerNorm(normalized_shape=d_model, epsilon=layer_norm_eps)
+ self.dropout1 = paddle.nn.Dropout(p=dropout)
+ self.dropout2 = paddle.nn.Dropout(p=dropout)
+ self.dropout3 = paddle.nn.Dropout(p=dropout)
+ if activation == 'gelu':
+ self.activation = paddle.nn.GELU()
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = paddle.nn.functional.gelu
+ super().__setstate__(state)
+
+ def forward_stream(self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask):
+ """Forward pass for a single stream (i.e. content or query)
+ tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
+ Both tgt_kv and memory are expected to be LayerNorm'd too.
+ memory is LayerNorm'd by ViT.
+ """
+ if tgt_key_padding_mask is not None:
+ tgt_mask1 = (tgt_mask!=float('-inf'))[None,None,:,:] & (tgt_key_padding_mask[:,None,None,:]==False)
+ tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1)
+ else:
+ tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask)
+
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt, sa_weights, ca_weights
+
+ def forward(self, query, content, memory, query_mask=None, content_mask=None, content_key_padding_mask=None, update_content=True):
+ query_norm = self.norm_q(query)
+ content_norm = self.norm_c(content)
+ query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
+ if update_content:
+ content = self.forward_stream(content, content_norm, content_norm, memory, content_mask, content_key_padding_mask)[0]
+ return query, content
+
+
+def get_clones(module, N):
+ return paddle.nn.LayerList([copy.deepcopy(module) for i in range(N)])
+
+
+class Decoder(paddle.nn.Layer):
+ __constants__ = ['norm']
+
+ def __init__(self, decoder_layer, num_layers, norm):
+ super().__init__()
+ self.layers = get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self, query, content, memory, query_mask: Optional[paddle.Tensor]=None, content_mask: Optional[paddle.Tensor]=None, content_key_padding_mask: Optional[paddle.Tensor]=None):
+ for i, mod in enumerate(self.layers):
+ last = i == len(self.layers) - 1
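+            # only the non-final layers update the content stream; the last layer just refines the query stream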
+ query, content = mod(query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last)
+ query = self.norm(query)
+ return query
+
+
+class TokenEmbedding(paddle.nn.Layer):
+
+ def __init__(self, charset_size: int, embed_dim: int):
+ super().__init__()
+ self.embedding = paddle.nn.Embedding(num_embeddings=charset_size, embedding_dim=embed_dim)
+ self.embed_dim = embed_dim
+
+ def forward(self, tokens: paddle.Tensor):
+ return math.sqrt(self.embed_dim) * self.embedding(tokens.astype(paddle.int64))
+
+
+def trunc_normal_init(param, **kwargs):
+ initializer = nn.initializer.TruncatedNormal(**kwargs)
+ initializer(param, param.block)
+
+
+def constant_init(param, **kwargs):
+ initializer = nn.initializer.Constant(**kwargs)
+ initializer(param, param.block)
+
+
+def kaiming_normal_init(param, **kwargs):
+ initializer = nn.initializer.KaimingNormal(**kwargs)
+ initializer(param, param.block)
+
+
+class ParseQHead(nn.Layer):
+ def __init__(self, out_channels, max_text_length, embed_dim, dec_num_heads, dec_mlp_ratio, dec_depth, perm_num, perm_forward, perm_mirrored, decode_ar, refine_iters, dropout, **kwargs):
+ super().__init__()
+
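+        # token ids follow ParseQLabelEncode.add_special_char: [E] is index 0, the charset follows, and [B] / [P] take the last two indices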
+ self.bos_id = out_channels - 2
+ self.eos_id = 0
+ self.pad_id = out_channels - 1
+
+ self.max_label_length = max_text_length
+ self.decode_ar = decode_ar
+ self.refine_iters = refine_iters
+ decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
+ self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=paddle.nn.LayerNorm(normalized_shape=embed_dim))
+ self.rng = np.random.default_rng()
+ self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
+ self.perm_forward = perm_forward
+ self.perm_mirrored = perm_mirrored
+ self.head = paddle.nn.Linear(in_features=embed_dim, out_features=out_channels - 2)
+ self.text_embed = TokenEmbedding(out_channels, embed_dim)
+ self.pos_queries = paddle.create_parameter(shape=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).shape, dtype=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).numpy().dtype, default_initializer=paddle.nn.initializer.Assign(paddle.empty(shape=[1, max_text_length + 1, embed_dim])))
+        self.pos_queries.stop_gradient = False
+ self.dropout = paddle.nn.Dropout(p=dropout)
+ self._device = self.parameters()[0].place
+ trunc_normal_init(self.pos_queries, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, paddle.nn.Linear):
+ trunc_normal_init(m.weight, std=0.02)
+ if m.bias is not None:
+ constant_init(m.bias, value=0.0)
+ elif isinstance(m, paddle.nn.Embedding):
+ trunc_normal_init(m.weight, std=0.02)
+ if m._padding_idx is not None:
+ m.weight.data[m._padding_idx].zero_()
+ elif isinstance(m, paddle.nn.Conv2D):
+ kaiming_normal_init(m.weight, fan_in=None, nonlinearity='relu')
+ if m.bias is not None:
+ constant_init(m.bias, value=0.0)
+ elif isinstance(m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)):
+ constant_init(m.weight, value=1.0)
+ constant_init(m.bias, value=0.0)
+
+ def no_weight_decay(self):
+ param_names = {'text_embed.embedding.weight', 'pos_queries'}
+ enc_param_names = {('encoder.' + n) for n in self.encoder.
+ no_weight_decay()}
+ return param_names.union(enc_param_names)
+
+ def encode(self, img):
+ return self.encoder(img)
+
+ def decode(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, tgt_query=None, tgt_query_mask=None):
+ N, L = tgt.shape
+ null_ctx = self.text_embed(tgt[:, :1])
+ if L != 1:
+ tgt_emb = self.pos_queries[:, :L - 1] + self.text_embed(tgt[:, 1:])
+ tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
+ else:
+ tgt_emb = self.dropout(null_ctx)
+ if tgt_query is None:
+ tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
+ tgt_query = self.dropout(tgt_query)
+ return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)
+
+ def forward_test(self, memory, max_length=None):
+ testing = max_length is None
+ max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+ bs = memory.shape[0]
+ num_steps = max_length + 1
+
+ pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
+ tgt_mask = query_mask = paddle.triu(x=paddle.full(shape=(num_steps, num_steps), fill_value=float('-inf')), diagonal=1)
+ if self.decode_ar:
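+            # greedy autoregressive decoding: feed back the argmax token at each step and, at test time, stop early once every sample has produced [E]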
+ tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype('int64')
+ tgt_in[:, (0)] = self.bos_id
+
+ logits = []
+ for i in range(paddle.to_tensor(num_steps)):
+ j = i + 1
+ tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], tgt_query_mask=query_mask[i:j, :j])
+ p_i = self.head(tgt_out)
+ logits.append(p_i)
+ if j < num_steps:
+ tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1)
+ if testing and (tgt_in == self.eos_id).astype('bool').any(axis=-1).astype('bool').all():
+ break
+ logits = paddle.concat(x=logits, axis=1)
+ else:
+ tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
+ tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
+ logits = self.head(tgt_out)
+ if self.refine_iters:
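+            # iterative refinement uses a "cloze" mask: the autoregressive mask with the token context to the right of each position unmasked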
+ temp = paddle.triu(x=paddle.ones(shape=[num_steps,num_steps], dtype='bool'), diagonal=2)
+ posi = np.where(temp.cpu().numpy()==True)
+ query_mask[posi] = 0
+ bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype('int64')
+ for i in range(self.refine_iters):
+ tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
+ tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype='int32')
+ tgt_padding_mask = tgt_padding_mask.cpu()
+ tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0
+ tgt_padding_mask = tgt_padding_mask.cuda().astype(dtype='float32')==1.0
+ tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]])
+ logits = self.head(tgt_out)
+
+ final_output = {"predict":logits}
+
+ return final_output
+
+ def gen_tgt_perms(self, tgt):
+ """Generate shared permutations for the whole batch.
+ This works because the same attention mask can be used for the shorter sequences
+ because of the padding mask.
+ """
+ max_num_chars = tgt.shape[1] - 2
+ if max_num_chars == 1:
+ return paddle.arange(end=3).unsqueeze(axis=0)
+ perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else []
+ max_perms = math.factorial(max_num_chars)
+ if self.perm_mirrored:
+ max_perms //= 2
+ num_gen_perms = min(self.max_gen_perms, max_perms)
+ if max_num_chars < 5:
+ if max_num_chars == 4 and self.perm_mirrored:
+ selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
+ else:
+ selector = list(range(max_perms))
+ perm_pool = paddle.to_tensor(data=list(permutations(range(max_num_chars), max_num_chars)), place=self._device)[selector]
+ if self.perm_forward:
+ perm_pool = perm_pool[1:]
+ perms = paddle.stack(x=perms)
+ if len(perm_pool):
+ i = self.rng.choice(len(perm_pool), size=num_gen_perms -
+ len(perms), replace=False)
+ perms = paddle.concat(x=[perms, perm_pool[i]])
+ else:
+ perms.extend([paddle.randperm(n=max_num_chars) for _ in range(num_gen_perms - len(perms))])
+ perms = paddle.stack(x=perms)
+ if self.perm_mirrored:
+ comp = perms.flip(axis=-1)
+ x = paddle.stack(x=[perms, comp])
+ perm_2 = list(range(x.ndim))
+ perm_2[0] = 1
+ perm_2[1] = 0
+ perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars))
+ bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
+ eos_idx = paddle.full(shape=(len(perms), 1), fill_value=
+ max_num_chars + 1, dtype=perms.dtype)
+ perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
+ if len(perms) > 1:
+ perms[(1), 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
+ return perms
+
+ def generate_attn_masks(self, perm):
+ """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)
+ :param perm: the permutation sequence. i = 0 is always the BOS
+ :return: lookahead attention masks
+ """
+ sz = perm.shape[0]
+ mask = paddle.zeros(shape=(sz, sz))
+ for i in range(sz):
+ query_idx = perm[i].cpu().numpy().tolist()
+ masked_keys = perm[i + 1:].cpu().numpy().tolist()
+ if len(masked_keys) == 0:
+ break
+ mask[query_idx, masked_keys] = float('-inf')
+ content_mask = mask[:-1, :-1].clone()
+ mask[paddle.eye(num_rows=sz).astype('bool')] = float('-inf')
+ query_mask = mask[1:, :-1]
+ return content_mask, query_mask
+
+ def forward_train(self, memory, tgt):
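+        # train with several permutations of the target order; each permutation gets its own attention masks and contributes one entry to logits_list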
+ tgt_perms = self.gen_tgt_perms(tgt)
+ tgt_in = tgt[:, :-1]
+ tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
+ logits_list = []
+ final_out = {}
+ for i, perm in enumerate(tgt_perms):
+ tgt_mask, query_mask = self.generate_attn_masks(perm)
+ out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask)
+ logits = self.head(out)
+ if i == 0:
+ final_out['predict'] = logits
+ logits = logits.flatten(stop_axis=1)
+ logits_list.append(logits)
+
+ final_out['logits_list'] = logits_list
+ final_out['pad_id'] = self.pad_id
+ final_out['eos_id'] = self.eos_id
+
+ return final_out
+
+ def forward(self, feat, targets=None):
+ # feat : B, N, C
+ # targets : labels, labels_len
+
+ if self.training:
+ label = targets[0] # label
+ label_len = targets[1]
+ max_step = paddle.max(label_len).cpu().numpy()[0] + 2
+ crop_label = label[:, :max_step]
+ final_out = self.forward_train(feat, crop_label)
+ else:
+ final_out = self.forward_test(feat)
+
+ return final_out
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index c89345e70b..d5093e6c0d 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -28,7 +28,7 @@
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
- SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode
+ SPINLabelDecode, VLLabelDecode, RFLLabelDecode, SATRNLabelDecode, ParseQLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
@@ -53,7 +53,7 @@ def build_post_process(config, global_config=None):
'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode',
- 'SATRNLabelDecode'
+ 'SATRNLabelDecode', 'ParseQLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index ce2e9f8b57..bdadb9ed56 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -559,6 +559,95 @@ def get_beg_end_flag_idx(self, beg_or_end):
% beg_or_end
return idx
+class ParseQLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+ BOS = '[B]'
+ EOS = '[E]'
+ PAD = '[P]'
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(ParseQLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+ self.max_text_length = kwargs.get('max_text_length', 25)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, dict):
+ pred = preds['predict']
+ else:
+ pred = preds
+
+        char_num = len(self.character_str) + 1  # We don't predict <bos> nor <pad>, with only <eos> addition
+ if isinstance(pred, paddle.Tensor):
+ pred = pred.numpy()
+ B, L = pred.shape[:2]
+ pred = np.reshape(pred, [-1, char_num])
+
+ preds_idx = np.argmax(pred, axis=1)
+ preds_prob = np.max(pred, axis=1)
+
+ preds_idx = np.reshape(preds_idx, [B, L])
+ preds_prob = np.reshape(preds_prob, [B, L])
+
+ if label is None:
+ text = self.decode(preds_idx, preds_prob, raw=False)
+ return text
+
+ text = self.decode(preds_idx, preds_prob, raw=False)
+ label = self.decode(label, None, False)
+
+ return text, label
+
+ def decode(self, text_index, text_prob=None, raw=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+
+ index = text_index[batch_idx, :]
+ prob = None
+ if text_prob is not None:
+ prob = text_prob[batch_idx, :]
+
+ if not raw:
+ index, prob = self._filter(index, prob)
+
+ for idx in range(len(index)):
+ if index[idx] in ignored_tokens:
+ continue
+ char_list.append(self.character[int(index[idx])])
+ if text_prob is not None:
+ conf_list.append(prob[idx])
+ else:
+ conf_list.append(1)
+
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+
+ return result_list
+
+ def add_special_char(self, dict_character):
+ dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD]
+ return dict_character
+
+ def _filter(self, ids, probs=None):
+ ids = ids.tolist()
+ try:
+ eos_idx = ids.index(self.dict[self.EOS])
+ except ValueError:
+ eos_idx = len(ids) # Nothing to truncate.
+ # Truncate after EOS
+ ids = ids[:eos_idx]
+ if probs is not None:
+ probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists)
+ return ids, probs
+
+ def get_ignored_tokens(self):
+ return [self.dict[self.BOS], self.dict[self.EOS], self.dict[self.PAD]]
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
diff --git a/ppocr/utils/dict/parseq_dict.txt b/ppocr/utils/dict/parseq_dict.txt
new file mode 100644
index 0000000000..1aef43d6b8
--- /dev/null
+++ b/ppocr/utils/dict/parseq_dict.txt
@@ -0,0 +1,94 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+{
+|
+}
+~
\ No newline at end of file
diff --git a/test_tipc/test_train_inference_python_xpu.sh b/test_tipc/test_train_inference_python_xpu.sh
index 7c6dc1e52a..c4add0056c 100644
--- a/test_tipc/test_train_inference_python_xpu.sh
+++ b/test_tipc/test_train_inference_python_xpu.sh
@@ -29,18 +29,28 @@ fi
sed -i 's/use_gpu/use_xpu/g' $FILENAME
# disable benchmark as AutoLog required nvidia-smi command
sed -i 's/--benchmark:True/--benchmark:False/g' $FILENAME
+# python has been updated to version 3.9 for xpu backend
+sed -i "s/python3.7/python3.9/g" $FILENAME
dataline=`cat $FILENAME`
# parser params
IFS=$'\n'
lines=(${dataline})
+modelname=$(echo ${lines[1]} | cut -d ":" -f2)
+if [ $modelname == "rec_r31_sar" ] || [ $modelname == "rec_mtb_nrtr" ]; then
+ sed -i "s/Global.epoch_num:lite_train_lite_infer=2/Global.epoch_num:lite_train_lite_infer=1/g" $FILENAME
+ sed -i "s/gpu_list:0|0,1/gpu_list:0,1/g" $FILENAME
+ sed -i "s/Global.use_xpu:True|True/Global.use_xpu:True/g" $FILENAME
+fi
+
# replace training config file
grep -n 'tools/.*yml' $FILENAME | cut -d ":" -f 1 \
| while read line_num ; do
train_cmd=$(func_parser_value "${lines[line_num-1]}")
trainer_config=$(func_parser_config ${train_cmd})
sed -i 's/use_gpu/use_xpu/g' "$REPO_ROOT_PATH/$trainer_config"
+ sed -i 's/use_sync_bn: True/use_sync_bn: False/g' "$REPO_ROOT_PATH/$trainer_config"
done
# change gpu to xpu in execution script
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 9dd33dc7b6..1bb9aa9e40 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -122,6 +122,12 @@ def __init__(self, args):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ elif self.rec_algorithm == "ParseQ":
+ postprocess_params = {
+ 'name': 'ParseQLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.postprocess_params = postprocess_params
self.predictor, self.input_tensor, self.output_tensors, self.config = \
@@ -439,7 +445,7 @@ def __call__(self, img_list):
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
- elif self.rec_algorithm in ["SVTR", "SATRN"]:
+ elif self.rec_algorithm in ["SVTR", "SATRN", "ParseQ"]:
norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
diff --git a/tools/program.py b/tools/program.py
index b01c2e43fe..d55961931d 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -231,7 +231,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SVTR_LCNet", "SPIN", "VisionLAN",
- "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet'
+ "RobustScanner", "RFL", 'DRRG', 'SATRN', 'SVTR_HGNet', "ParseQ",
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
@@ -664,7 +664,7 @@ def preprocess(is_train=False):
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'SVTR_LCNet', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN',
'VisionLAN', 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG',
- 'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet'
+ 'CAN', 'Telescope', 'SATRN', 'SVTR_HGNet', 'ParseQ',
]
if use_xpu: