Add new recognition method "ParseQ" (#10836)
* Update PP-OCRv4_introduction.md
* Update PP-OCRv4_introduction.md (#10616)
* Update PP-OCRv4_introduction.md
* Update PP-OCRv4_introduction.md
* Update PP-OCRv4_introduction.md
* Update README.md
* Cherrypicking GH-10217 and GH-10216 to PaddlePaddle:Release/2.7 (#10655)
* Don't break overall processing on a bad image
* Add preprocessing common to OCR tasks
  Add preprocessing to options
* Update requirements.txt (#10656)
  added missing pyyaml library
* [TIPC]update xpu tipc script (#10658)
* fix-typo (#10642)
  Co-authored-by: Dennis <dvorst@users.noreply.github.com>
  Co-authored-by: shiyutang <34859558+shiyutang@users.noreply.github.com>
* Fix DSR error caused by data augmentation (#10662) (#10681)
* Fix DSR error caused by data augmentation
* Roll back erroneous change
* Update algorithm_overview_en.md (#10670)
  Fixed simple spelling errors.
* Implement recognition method ParseQ
* Document update for new recognition method ParseQ
* add prediction for parseq
* Update rec_vit_parseq.yml
* Update rec_r31_sar.yml
* Update rec_r31_sar.yml
* Update rec_r50_fpn_srn.yml
* Update rec_vit_parseq.py
* Update rec_vit_parseq.yml
* Update rec_parseq_head.py
* Update rec_img_aug.py
* Update rec_vit_parseq.yml
* Update __init__.py
* Update predict_rec.py
* Update paddleocr.py
* Update requirements.txt
* Update utility.py
* Update utility.py

---------

Co-authored-by: xiaoting <31891223+tink2123@users.noreply.github.com>
Co-authored-by: topduke <784990967@qq.com>
Co-authored-by: dyning <dyning.2003@163.com>
Co-authored-by: UserUnknownFactor <63057995+UserUnknownFactor@users.noreply.github.com>
Co-authored-by: itasli <ilyas.tasli@outlook.fr>
Co-authored-by: Kai Song <50285351+USTCKAY@users.noreply.github.com>
Co-authored-by: dvorst <87502756+dvorst@users.noreply.github.com>
Co-authored-by: Dennis <dvorst@users.noreply.github.com>
Co-authored-by: shiyutang <34859558+shiyutang@users.noreply.github.com>
Co-authored-by: Dec20B <1192152456@qq.com>
Co-authored-by: ncoffman <51147417+ncoffman@users.noreply.github.com>
1 parent ab86490, commit 75d1661
Showing 24 changed files with 1,404 additions and 25 deletions.
@@ -0,0 +1,116 @@
Global:
  use_gpu: True
  epoch_num: 20
  log_smooth_window: 20
  print_batch_step: 5
  save_model_dir: ./output/rec/parseq
  save_epoch_step: 3
  # evaluation is run every 500 iterations after the 0th iteration
  eval_batch_step: [0, 500]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path: ppocr/utils/dict/parseq_dict.txt
  character_type: en
  max_text_length: 25
  num_heads: 8
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/predicts_parseq.txt


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: OneCycle
    max_lr: 0.0007

Architecture:
  model_type: rec
  algorithm: ParseQ
  in_channels: 3
  Transform:
  Backbone:
    name: ViTParseQ
    img_size: [32, 128]
    patch_size: [4, 8]
    embed_dim: 384
    depth: 12
    num_heads: 6
    mlp_ratio: 4
    in_channels: 3
  Head:
    name: ParseQHead
    # Architecture
    max_text_length: 25
    embed_dim: 384
    dec_num_heads: 12
    dec_mlp_ratio: 4
    dec_depth: 1
    # Training
    perm_num: 6
    perm_forward: true
    perm_mirrored: true
    dropout: 0.1
    # Decoding mode (test)
    decode_ar: true
    refine_iters: 1

Loss:
  name: ParseQLoss

PostProcess:
  name: ParseQLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc
  is_filter: True

Train:
  dataset:
    name: LMDBDataSet
    data_dir:
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - ParseQRecAug:
          aug_type: 0 # or 1
      - ParseQLabelEncode:
      - SVTRRecResizeImg:
          image_shape: [3, 32, 128]
          padding: False
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 192
    drop_last: True
    num_workers: 4

Eval:
  dataset:
    name: LMDBDataSet
    data_dir:
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - ParseQLabelEncode: # Class handling label
      - SVTRRecResizeImg:
          image_shape: [3, 32, 128]
          padding: False
      - KeepKeys:
          keep_keys: ['image', 'label', 'length']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 384
    num_workers: 4
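
Review note: a few quantities implied by this config are easy to check by hand. The sketch below is illustrative only and is not part of the commit. It copies the relevant values into a plain dict and derives the ViT patch grid, the per-head width, and the evaluation schedule. The helper names are hypothetical, and the `should_eval` rule assumes the trainer reads `eval_batch_step` as `[start_step, interval]`, which is how the comment in the config describes it.

```python
# Values copied from the config above; the dict layout is for illustration only.
cfg = {
    "eval_batch_step": [0, 500],
    "backbone": {"img_size": [32, 128], "patch_size": [4, 8],
                 "embed_dim": 384, "num_heads": 6},
    "head": {"embed_dim": 384, "dec_num_heads": 12},
}


def patch_grid(img_size, patch_size):
    """Return (rows, cols) of non-overlapping patches for an H x W image."""
    (height, width), (patch_h, patch_w) = img_size, patch_size
    if height % patch_h or width % patch_w:
        raise ValueError("image size must be divisible by patch size")
    return height // patch_h, width // patch_w


def head_dim(embed_dim, num_heads):
    """Width of one attention head; embed_dim must split evenly across heads."""
    if embed_dim % num_heads:
        raise ValueError("embed_dim must be divisible by num_heads")
    return embed_dim // num_heads


def should_eval(global_step, eval_batch_step):
    """Assumed rule: evaluate every `interval` steps once past `start`."""
    start, interval = eval_batch_step
    return global_step > start and (global_step - start) % interval == 0


rows, cols = patch_grid(**cfg["backbone"] and {
    "img_size": cfg["backbone"]["img_size"],
    "patch_size": cfg["backbone"]["patch_size"],
})
num_tokens = rows * cols
encoder_head_dim = head_dim(cfg["backbone"]["embed_dim"], cfg["backbone"]["num_heads"])
decoder_head_dim = head_dim(cfg["head"]["embed_dim"], cfg["head"]["dec_num_heads"])
```

Under these assumptions a 32x128 input yields an 8x16 grid of 128 patch tokens, encoder heads are 64 wide, and decoder heads are 32 wide.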
@@ -0,0 +1,124 @@
# ParseQ

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
    - [3.1 Training](#3-1)
    - [3.2 Evaluation](#3-2)
    - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
    - [4.1 Python Inference](#4-1)
    - [4.2 C++ Inference](#4-2)
    - [4.3 Serving](#4-3)
    - [4.4 More](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Scene Text Recognition with Permuted Autoregressive Sequence Models](https://arxiv.org/abs/2207.06966)
> Darwin Bautista, Rowel Atienza
> ECCV, 2022

The original paper trains on a real text recognition dataset (Real) and on a synthetic text recognition dataset (Synth), and evaluates on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets.
Where:
- The real text recognition dataset (Real) comprises COCO-Text, RCTW17, Uber-Text, ArT, LSVT, MLT19, ReCTS, TextOCR, and OpenVINO.
- The synthetic text recognition dataset (Synth) comprises MJSynth and SynthText.

The reproduced results of the algorithm trained on each dataset are as follows:

|Training data|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- | --- |
|Synth|ParseQ|ViT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|91.24%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_synth.tgz)|
|Real|ParseQ|ViT|[rec_vit_parseq.yml](../../configs/rec/rec_vit_parseq.yml)|94.74%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)|

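The "permuted autoregressive" idea in the paper title, together with the `perm_num`, `perm_forward`, and `perm_mirrored` options in the config, can be illustrated with a simplified sketch. It is not the code from this commit: it ignores the begin/end tokens and the separate query stream that the full method uses, and the function names are hypothetical. It samples decoding orders, pairs each with its reverse when mirroring is on, and turns an order into an attention mask in which each position may only look at positions decoded earlier in that order.

```python
import math
import random


def sample_permutations(seq_len, perm_num=6, perm_forward=True,
                        perm_mirrored=True, seed=0):
    """Sample decoding orders over positions 0..seq_len-1 (simplified)."""
    rng = random.Random(seed)
    # With mirroring, only half of the orders are sampled; the rest are reverses.
    target = perm_num // 2 if perm_mirrored else perm_num
    target = min(target, math.factorial(seq_len))
    perms = []
    if perm_forward and target > 0:
        perms.append(tuple(range(seq_len)))  # plain left-to-right order
    while len(perms) < target:
        candidate = tuple(rng.sample(range(seq_len), seq_len))
        if candidate not in perms:
            perms.append(candidate)
    if perm_mirrored:
        mirrored = []
        for perm in perms:
            mirrored.extend([perm, tuple(reversed(perm))])
        perms = mirrored
    return perms


def attention_mask(perm):
    """mask[q][k] is True when position k is decoded before position q."""
    rank = {position: step for step, position in enumerate(perm)}
    size = len(perm)
    return [[rank[k] < rank[q] for k in range(size)] for q in range(size)]


orders = sample_permutations(seq_len=5)
left_to_right_mask = attention_mask((0, 1, 2))
right_to_left_mask = attention_mask((2, 1, 0))
```

With the left-to-right order the mask is the usual strictly lower-triangular causal mask; with the reversed order it becomes strictly upper-triangular, which is why one set of weights can be trained to decode in several directions.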
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and to ["Project Clone"](./clone.md) to clone the project code.


<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

Please refer to the [text recognition tutorial](./recognition.md). PaddleOCR's code is modular, so training a different recognition model only requires **changing the configuration file**.

<a name="3-1"></a>
### 3.1 Training

Once the data is prepared, start training with the following commands:

```
# Single-GPU training (slow; not recommended)
python3 tools/train.py -c configs/rec/rec_vit_parseq.yml
# Multi-GPU training; specify the GPU ids with the --gpus argument
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_vit_parseq.yml
```

<a name="3-2"></a>
### 3.2 Evaluation

```
# GPU evaluation; Global.pretrained_model is the weights to evaluate
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```

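The `-o Global.pretrained_model=...` argument used above overrides a single key of the YAML config from the command line. As a rough illustration of the idea, the sketch below applies dotted `key=value` overrides to a nested dict. `apply_overrides` is a hypothetical helper written for this note, not the function the tools use, and it keeps every value as a string, whereas the real tools may convert types.

```python
def apply_overrides(config, overrides):
    """Apply 'Section.key=value' strings to a nested dict, in place."""
    for item in overrides:
        key, separator, value = item.partition("=")
        if not separator:
            raise ValueError(f"override must look like key=value: {item!r}")
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            # Create intermediate sections that do not exist yet.
            node = node.setdefault(name, {})
        node[leaf] = value
    return config


# A trimmed copy of the Global section, for illustration only.
config = {"Global": {"use_gpu": True, "pretrained_model": None}}
apply_overrides(
    config,
    ["Global.pretrained_model=./output/rec/parseq/best_accuracy"],
)
```

Keys that are not named in an override, such as `Global.use_gpu` here, keep the values from the config file.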
<a name="3-3"></a>
### 3.3 Prediction

```
# The config file used for prediction must match the one used for training
python3 tools/infer_rec.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during ParseQ text recognition training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/parseq/rec_vit_parseq_real.tgz)). Use the following command for the conversion:

```
python3 tools/export_model.py -c configs/rec/rec_vit_parseq.yml -o Global.pretrained_model=./rec_vit_parseq_real/best_accuracy Global.save_inference_dir=./inference/rec_parseq
```

To run inference with the ParseQ text recognition model, execute the following command:

```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_parseq/" --rec_image_shape="3, 32, 128" --rec_algorithm="ParseQ" --rec_char_dict_path="ppocr/utils/dict/parseq_dict.txt" --max_text_length=25 --use_space_char=False
```

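After the model runs, its per-step class probabilities still have to be turned into a string, which is the job of the post-processing step named `ParseQLabelDecode` in the config. The sketch below shows a generic greedy decode and should be read as an assumption-laden illustration rather than that class's implementation. It assumes class index 0 is the end-of-sequence token, that indices 1..N map to the characters of the dictionary in order, that any higher index is a special token to skip, and that the reported score is the mean confidence of the kept characters.

```python
def greedy_decode(step_probs, charset, eos_index=0, max_text_length=25):
    """Pick the most likely class per step and stop at end-of-sequence."""
    chars, confidences = [], []
    for probs in step_probs:
        if len(chars) >= max_text_length:
            break
        best = max(range(len(probs)), key=probs.__getitem__)
        if best == eos_index:
            break  # assumed end-of-sequence token
        char_index = best - 1  # assumed offset: index 0 is reserved for EOS
        if char_index >= len(charset):
            continue  # assumed special token (for example padding); skip it
        chars.append(charset[char_index])
        confidences.append(probs[best])
    score = sum(confidences) / len(confidences) if confidences else 0.0
    return "".join(chars), score


# Toy example: three characters, so four classes including EOS at index 0.
toy_probs = [
    [0.1, 0.7, 0.1, 0.1],    # most likely class 1 -> "a"
    [0.05, 0.05, 0.1, 0.8],  # most likely class 3 -> "c"
    [0.9, 0.05, 0.03, 0.02], # most likely class 0 -> EOS, decoding stops
    [0.0, 1.0, 0.0, 0.0],    # never reached
]
text, score = greedy_decode(toy_probs, charset="abc")
```

The `--max_text_length=25` flag in the command above corresponds to the cap applied here.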
<a name="4-2"></a>
### 4.2 C++ Inference

Not supported yet, because the C++ pre-processing and post-processing do not support ParseQ.

<a name="4-3"></a>
### 4.3 Serving

Not supported yet.

<a name="4-4"></a>
### 4.4 More

Not supported yet.

<a name="5"></a>
## 5. FAQ


## Citation

```bibtex
@InProceedings{bautista2022parseq,
  title={Scene Text Recognition with Permuted Autoregressive Sequence Models},
  author={Bautista, Darwin and Atienza, Rowel},
  booktitle={European Conference on Computer Vision},
  pages={178--196},
  month={10},
  year={2022},
  publisher={Springer Nature Switzerland},
  address={Cham},
  doi={10.1007/978-3-031-19815-1_11},
  url={https://doi.org/10.1007/978-3-031-19815-1_11}
}
```