[Refactor] refactor transformer modules #618

Merged: 24 commits, Dec 4, 2021
4 changes: 2 additions & 2 deletions configs/_base_/recog_models/nrtr_modality_transform.py
Expand Up @@ -4,8 +4,8 @@
model = dict(
type='NRTR',
backbone=dict(type='NRTRModalityTransform'),
encoder=dict(type='TFEncoder'),
decoder=dict(type='TFDecoder'),
encoder=dict(type='NRTREncoder', n_layers=12),
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)
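
The renamed `type` strings matter because configs of this style are resolved by name at build time. A minimal sketch of that lookup, assuming a hypothetical dictionary-based registry (`REGISTRY`, `register`, and `build` are illustrative names, and the stub classes and their default `n_layers=6` are placeholders rather than the project's real modules):

```python
# Hypothetical, simplified registry: maps a class name to the class itself.
REGISTRY = {}


def register(cls):
    """Record a class under its own name so configs can refer to it."""
    REGISTRY[cls.__name__] = cls
    return cls


@register
class NRTREncoder:
    # Placeholder stub; the default of 6 layers is an assumption.
    def __init__(self, n_layers=6):
        self.n_layers = n_layers


@register
class NRTRDecoder:
    # Placeholder stub; the default of 6 layers is an assumption.
    def __init__(self, n_layers=6):
        self.n_layers = n_layers


def build(cfg):
    """Instantiate the class named by cfg['type'] with the remaining keys."""
    args = dict(cfg)  # copy so the caller's config is not mutated
    cls = REGISTRY[args.pop('type')]
    return cls(**args)


encoder = build(dict(type='NRTREncoder', n_layers=12))
decoder = build(dict(type='NRTRDecoder'))
print(type(encoder).__name__, encoder.n_layers)
print('TFEncoder' in REGISTRY)
```

Under this scheme a config that still says `type='TFEncoder'` would fail the lookup, which is presumably why every config touched by this PR updates the name.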
31 changes: 12 additions & 19 deletions configs/_base_/recog_pipelines/nrtr_pipeline.py
Expand Up @@ -17,26 +17,19 @@
'filename', 'ori_shape', 'resize_shape', 'text', 'valid_ratio'
]),
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiRotateAugOCR',
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=32,
min_width=32,
max_width=160,
keep_aspect_ratio=True,
width_downsample_ratio=0.25),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'resize_shape', 'valid_ratio'
]),
])
type='ResizeOCR',
height=32,
min_width=32,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=['filename', 'ori_shape', 'resize_shape', 'valid_ratio'])
]
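
The `ResizeOCR` arguments above fix the height and bound the width. A rough sketch of the width arithmetic this implies, assuming aspect-preserving scaling followed by clamping (`resized_width` is a hypothetical helper, not the library's implementation, which may also round or pad the width):

```python
import math


def resized_width(ori_h, ori_w, height=32, min_width=32, max_width=160):
    """Scale the width to keep the aspect ratio at the target height,
    then clamp it to [min_width, max_width]. Simplified assumption."""
    scaled = math.ceil(height / ori_h * ori_w)
    return max(min_width, min(scaled, max_width))


print(resized_width(64, 256))   # image twice the target height
print(resized_width(32, 1000))  # very wide image, hits max_width
print(resized_width(100, 50))   # tall image, hits min_width
```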
14 changes: 10 additions & 4 deletions configs/textrecog/nrtr/README.md
Expand Up @@ -66,10 +66,16 @@ Backbone
| Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
| :-------------------------------------------------------------: | :----------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py) | R31-1/16-1/8 | 93.9 | 90.0 | 93.5 | | 74.5 | 78.5 | 86.5 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_010150.log.json) |
| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py) | R31-1/8-1/4 | 94.7 | 87.5 | 93.3 | | 75.1 | 78.9 | 87.9 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_160845.log.json) |
| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py) | R31-1/16-1/8 | 94.7 | 87.3 | 94.3 | | 73.5 | 78.9 | 85.1 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20211124_002420.log.json) |
| [NRTR](/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py) | R31-1/8-1/4 | 95.2 | 90.0 | 94.0 | | 74.1 | 79.4 | 88.2 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20211123_232151.log.json) |

**Notes:**

- `R31-1/16-1/8` means the backbone's output feature map has 1/16 of the input image's height and 1/8 of its width.
- `R31-1/8-1/4` means the backbone's output feature map has 1/8 of the input image's height and 1/4 of its width.
- For backbone `R31-1/16-1/8`:
  - The output consists of 92 classes: 26 lowercase letters, 26 uppercase letters, 28 symbols, 10 digits, 1 unknown token and 1 end-of-sequence token.
  - The encoder has 6 blocks.
  - `1/16-1/8` means the backbone's output feature map has 1/16 of the input image's height and 1/8 of its width.
- For backbone `R31-1/8-1/4`:
  - The output consists of 92 classes: 26 lowercase letters, 26 uppercase letters, 28 symbols, 10 digits, 1 unknown token and 1 end-of-sequence token.
  - The encoder has 6 blocks.
  - `1/8-1/4` means the backbone's output feature map has 1/8 of the input image's height and 1/4 of its width.
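
The numbers in these notes can be checked with a few lines of arithmetic. The sketch below assumes a 32x160 input (the height and maximum width used in the NRTR pipeline) and integer division for the downsampling; `feature_size` is a hypothetical helper.

```python
# Class breakdown as listed in the notes.
class_counts = {
    'lowercase letters': 26,
    'uppercase letters': 26,
    'symbols': 28,
    'digits': 10,
    'unknown token': 1,
    'end-of-sequence token': 1,
}
num_classes = sum(class_counts.values())


def feature_size(img_h, img_w, h_div, w_div):
    """Feature map (height, width) for the given downsampling divisors."""
    return img_h // h_div, img_w // w_div


print(num_classes)
print(feature_size(32, 160, 16, 8))  # R31-1/16-1/8
print(feature_size(32, 160, 8, 4))   # R31-1/8-1/4
```

The `1/8-1/4` backbone keeps twice the resolution along each axis, so the encoder sees four times as many positions for the same input.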
40 changes: 20 additions & 20 deletions configs/textrecog/nrtr/metafile.yml
Expand Up @@ -4,13 +4,13 @@ Collections:
Training Data: OCRDataset
Training Techniques:
- Adam
Epochs: 5
Batch Size: 8192
Training Resources: 64x GeForce GTX 1080 Ti
Epochs: 6
Batch Size: 6144
Training Resources: 48x GeForce GTX 1080 Ti
Architecture:
- ResNet31OCR
- TFEncoder
- TFDecoder
- CNN
- NRTREncoder
- NRTRDecoder
Paper:
URL: https://arxiv.org/pdf/1806.00926.pdf
Title: 'NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition'
Expand All @@ -28,28 +28,28 @@ Models:
- Task: Text Recognition
Dataset: IIIT5K
Metrics:
word_acc: 93.9
word_acc: 94.7
- Task: Text Recognition
Dataset: SVT
Metrics:
word_acc: 80.0
word_acc: 87.3
- Task: Text Recognition
Dataset: ICDAR2013
Metrics:
word_acc: 93.5
word_acc: 94.3
- Task: Text Recognition
Dataset: ICDAR2015
Metrics:
word_acc: 74.5
word_acc: 73.5
- Task: Text Recognition
Dataset: SVTP
Metrics:
word_acc: 78.5
word_acc: 78.9
- Task: Text Recognition
Dataset: CT80
Metrics:
word_acc: 86.5
Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth
word_acc: 85.1
Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth

- Name: nrtr_r31_1by8_1by4_academic
In Collection: NRTR
Expand All @@ -62,25 +62,25 @@ Models:
- Task: Text Recognition
Dataset: IIIT5K
Metrics:
word_acc: 94.7
word_acc: 95.2
- Task: Text Recognition
Dataset: SVT
Metrics:
word_acc: 87.5
word_acc: 90.0
- Task: Text Recognition
Dataset: ICDAR2013
Metrics:
word_acc: 93.3
word_acc: 94.0
- Task: Text Recognition
Dataset: ICDAR2015
Metrics:
word_acc: 75.1
word_acc: 74.1
- Task: Text Recognition
Dataset: SVTP
Metrics:
word_acc: 78.9
word_acc: 79.4
- Task: Text Recognition
Dataset: CT80
Metrics:
word_acc: 87.9
Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth
word_acc: 88.2
Weights: https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth
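
The collection-level `Batch Size` values are consistent with the GPU count multiplied by `samples_per_gpu=128`, the value set in `nrtr_r31_1by16_1by8_academic.py`. A quick check, assuming that is how the figure was derived (`total_batch_size` is a hypothetical helper; `nrtr_r31_1by8_1by4_academic.py` sets `samples_per_gpu=64`, so the same product would differ for it):

```python
def total_batch_size(num_gpus, samples_per_gpu):
    """Effective batch size across all GPUs, assuming no gradient accumulation."""
    return num_gpus * samples_per_gpu


old_batch = total_batch_size(64, 128)  # previous metafile: 64 GPUs
new_batch = total_batch_size(48, 128)  # updated metafile: 48 GPUs
print(old_batch, new_batch)
```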
32 changes: 32 additions & 0 deletions configs/textrecog/nrtr/nrtr_modality_transform_academic.py
@@ -0,0 +1,32 @@
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/recog_models/nrtr_modality_transform.py',
'../../_base_/schedules/schedule_adam_step_6e.py',
'../../_base_/recog_datasets/ST_MJ_train.py',
'../../_base_/recog_datasets/academic_test.py',
'../../_base_/recog_pipelines/nrtr_pipeline.py'
]

train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}

train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}

data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline),
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline))

evaluation = dict(interval=1, metric='acc')
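
This config defines almost nothing itself; the model, schedule, datasets and pipelines come from the `_base_` files. A simplified sketch of how such inheritance can be pictured, assuming a plain recursive dictionary merge (`merge` is a hypothetical helper and a simplification of what the config loader actually does; the `n_layers=6` override is purely illustrative):

```python
def merge(base, override):
    """Recursively merge `override` into a copy of `base`; child keys win."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# Fragment of the base model config shown earlier in this PR.
base = dict(
    model=dict(
        type='NRTR',
        encoder=dict(type='NRTREncoder', n_layers=12),
        max_seq_len=40))

# Child config: adds data settings and (hypothetically) overrides one field.
child = dict(
    model=dict(encoder=dict(n_layers=6)),
    data=dict(samples_per_gpu=128, workers_per_gpu=4))

cfg = merge(base, child)
print(cfg['model']['encoder'])
print(cfg['model']['max_seq_len'], cfg['data']['samples_per_gpu'])
```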
31 changes: 31 additions & 0 deletions configs/textrecog/nrtr/nrtr_modality_transform_toy_dataset.py
@@ -0,0 +1,31 @@
_base_ = [
'../../_base_/default_runtime.py',
'../../_base_/recog_models/nrtr_modality_transform.py',
'../../_base_/schedules/schedule_adam_step_6e.py',
'../../_base_/recog_datasets/toy_data.py',
'../../_base_/recog_pipelines/nrtr_pipeline.py'
]

train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}

train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}

data = dict(
samples_per_gpu=16,
workers_per_gpu=2,
train=dict(
type='UniformConcatDataset',
datasets=train_list,
pipeline=train_pipeline),
val=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline),
test=dict(
type='UniformConcatDataset',
datasets=test_list,
pipeline=test_pipeline))

evaluation = dict(interval=1, metric='acc')
6 changes: 2 additions & 4 deletions configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py
Expand Up @@ -23,17 +23,15 @@
channels=[32, 64, 128, 256, 512, 512],
stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
last_stage_pool=True),
encoder=dict(type='TFEncoder'),
decoder=dict(type='TFDecoder'),
encoder=dict(type='NRTREncoder'),
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)

data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,
6 changes: 2 additions & 4 deletions configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py
Expand Up @@ -23,17 +23,15 @@
channels=[32, 64, 128, 256, 512, 512],
stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
last_stage_pool=False),
encoder=dict(type='TFEncoder'),
decoder=dict(type='TFDecoder'),
encoder=dict(type='NRTREncoder'),
decoder=dict(type='NRTRDecoder'),
loss=dict(type='TFLoss'),
label_convertor=label_convertor,
max_seq_len=40)

data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='UniformConcatDataset',
datasets=train_list,
2 changes: 1 addition & 1 deletion configs/textrecog/satrn/satrn_academic.py
Expand Up @@ -28,7 +28,7 @@
d_inner=512 * 4,
dropout=0.1),
decoder=dict(
type='TFDecoder',
type='NRTRDecoder',
n_layers=6,
d_embedding=512,
n_head=8,
2 changes: 1 addition & 1 deletion configs/textrecog/satrn/satrn_small.py
Expand Up @@ -28,7 +28,7 @@
d_inner=256 * 4,
dropout=0.1),
decoder=dict(
type='TFDecoder',
type='NRTRDecoder',
n_layers=6,
d_embedding=256,
n_head=8,
34 changes: 29 additions & 5 deletions mmocr/apis/inference.py
Expand Up @@ -68,12 +68,36 @@ def disable_text_recog_aug_test(cfg, set_types=None):
assert set_types is None or isinstance(set_types, list)
if set_types is None:
set_types = ['val', 'test']
warnings.simplefilter('once')
warning_msg = 'Remove "MultiRotateAugOCR" to support batch ' + \
'inference since samples_per_gpu > 1.'
for set_type in set_types:
if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
cfg.data[set_type].pipeline = [
cfg.data[set_type].pipeline[0],
*cfg.data[set_type].pipeline[1].transforms
]
dataset_type = cfg.data[set_type].type
if dataset_type in ['OCRDataset', 'OCRSegDataset']:
if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
warnings.warn(warning_msg)
cfg.data[set_type].pipeline = [
cfg.data[set_type].pipeline[0],
*cfg.data[set_type].pipeline[1].transforms
]
elif dataset_type in ['ConcatDataset', 'UniformConcatDataset']:
if dataset_type == 'UniformConcatDataset':
uniform_pipeline = cfg.data[set_type].pipeline
if uniform_pipeline is not None:
if uniform_pipeline[1].type == 'MultiRotateAugOCR':
warnings.warn(warning_msg)
cfg.data[set_type].pipeline = [
uniform_pipeline[0],
*uniform_pipeline[1].transforms
]
for dataset in cfg.data[set_type].datasets:
if dataset.pipeline is not None:
if dataset.pipeline[1].type == 'MultiRotateAugOCR':
warnings.warn(warning_msg)
dataset.pipeline = [
dataset.pipeline[0],
*dataset.pipeline[1].transforms
]

return cfg
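
The core transformation in this function is the same in every branch: if the second pipeline step is a `MultiRotateAugOCR` wrapper, replace it with the transforms it wraps. A standalone sketch of that step on plain dictionaries (`flatten_multi_rotate` is a hypothetical helper; the real code uses attribute access on config objects and also emits a warning):

```python
def flatten_multi_rotate(pipeline):
    """Return the pipeline with a MultiRotateAugOCR wrapper at index 1
    replaced by its inner transforms; otherwise return it unchanged."""
    if (pipeline is not None and len(pipeline) > 1
            and pipeline[1].get('type') == 'MultiRotateAugOCR'):
        return [pipeline[0], *pipeline[1]['transforms']]
    return pipeline


wrapped = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(type='ResizeOCR', height=32),
            dict(type='ToTensorOCR'),
        ]),
]

flat = flatten_multi_rotate(wrapped)
print([step['type'] for step in flat])
```

Applying the helper to an already flat pipeline, or to `None`, leaves it untouched, which mirrors the `is not None` guards in the diff.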

6 changes: 4 additions & 2 deletions mmocr/models/common/__init__.py
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import backbones, losses
from . import backbones, layers, losses, modules

from .backbones import * # NOQA
from .losses import * # NOQA
from .layers import * # NOQA
from .modules import * # NOQA

__all__ = backbones.__all__ + losses.__all__
__all__ = backbones.__all__ + losses.__all__ + layers.__all__ + modules.__all__
3 changes: 3 additions & 0 deletions mmocr/models/common/layers/__init__.py
@@ -0,0 +1,3 @@
from .transformer_layers import TFDecoderLayer, TFEncoderLayer

__all__ = ['TFEncoderLayer', 'TFDecoderLayer']