refactor tps config #135

Merged 3 commits on May 12, 2021
5 changes: 4 additions & 1 deletion configs/textrecog/crnn/README.md
@@ -28,10 +28,13 @@
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| IC15 | 2077 |irregular|
| SVTP | 645 |irregular|
| CT80 | 288 |irregular|

## Results and models

| methods | | Regular Text | | | | Irregular Text | | download |
| :------------------------------------------------------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| methods | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| [CRNN](/configs/textrecog/crnn/crnn_academic_dataset.py) | 80.5 | 81.5 | 86.5 | | - | - | - | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic-a723a1c5.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json) |
| [CRNN](/configs/textrecog/crnn/crnn_academic_dataset.py) | 80.5 | 81.5 | 86.5 | | 54.1 | 59.1 | 55.6 | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic-a723a1c5.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json) |
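The results table reports one accuracy per benchmark. When a single summary number is wanted, one common option is an average weighted by test-set size. The sketch below uses the instance counts and the updated CRNN accuracies from the tables above; the helper names `weighted_accuracy` and `mean_accuracy` are illustrative assumptions and are not part of MMOCR or this PR.

```python
# Test-set sizes, copied from the "Test Dataset" table.
instance_num = {
    'IIIT5K': 3000, 'SVT': 647, 'IC13': 1015,
    'IC15': 2077, 'SVTP': 645, 'CT80': 288,
}

# CRNN accuracies (%), copied from the updated results row.
crnn_acc = {
    'IIIT5K': 80.5, 'SVT': 81.5, 'IC13': 86.5,
    'IC15': 54.1, 'SVTP': 59.1, 'CT80': 55.6,
}


def weighted_accuracy(acc, sizes):
    """Average accuracy weighted by the number of test instances."""
    total = sum(sizes[name] for name in acc)
    return sum(acc[name] * sizes[name] for name in acc) / total


def mean_accuracy(acc):
    """Plain (unweighted) mean over benchmarks."""
    return sum(acc.values()) / len(acc)


if __name__ == '__main__':
    print(f'weighted: {weighted_accuracy(crnn_acc, instance_num):.2f}')
    print(f'unweighted: {mean_accuracy(crnn_acc):.2f}')
```

The weighted figure is dominated by IIIT5K and IC15 because they are the largest test sets, so the two averages can differ noticeably.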
39 changes: 31 additions & 8 deletions configs/textrecog/crnn/crnn_academic_dataset.py
@@ -95,13 +95,20 @@
test_mode=False)

test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'icdar_2013/'
test_img_prefix2 = test_prefix + 'IIIT5K/'
test_img_prefix3 = test_prefix + 'svt/'

test_ann_file1 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file2 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file3 = test_prefix + 'svt/test_label.txt'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'

test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'

test1 = dict(
type=dataset_type,
@@ -126,12 +133,28 @@
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3

test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4

test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5

test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6

data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
train=dict(type='ConcatDataset', datasets=[train1]),
val=dict(type='ConcatDataset', datasets=[test1, test2, test3]),
test=dict(type='ConcatDataset', datasets=[test1, test2, test3]))
val=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]),
test=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]))

evaluation = dict(interval=1, metric='acc')
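
The copy-and-override pattern used for `test2` through `test6` can also be written as a loop. The following is only a sketch and not part of this PR: `base_test` is a hypothetical stand-in for the shared fields of `test1` (which this hunk elides), and `build_test_datasets` is an illustrative helper name.

```python
import copy

test_prefix = 'data/mixture/'

# (image sub-directory, annotation file name) per benchmark, in the order
# used by this PR: IIIT5K, SVT, IC13, IC15, SVTP, CT80.
test_sets = [
    ('IIIT5K', 'test_label.txt'),
    ('svt', 'test_label.txt'),
    ('icdar_2013', 'test_label_1015.txt'),
    ('icdar_2015', 'test_label.txt'),
    ('svtp', 'test_label.txt'),
    ('ct80', 'test_label.txt'),
]

# Hypothetical placeholder for the fields shared by every test set; the
# real config defines them in `test1`.
base_test = dict(type='OCRDataset', pipeline=[], test_mode=True)


def build_test_datasets(prefix, sets, base):
    """Return one dataset dict per benchmark, derived from `base`."""
    datasets = []
    for subdir, ann_name in sets:
        cfg = copy.deepcopy(base)  # avoid sharing nested objects
        cfg['img_prefix'] = f'{prefix}{subdir}/'
        cfg['ann_file'] = f'{prefix}{subdir}/{ann_name}'
        datasets.append(cfg)
    return datasets


test_list = build_test_datasets(test_prefix, test_sets, base_test)
```

The resulting list could then be passed as `datasets=test_list` to both `val` and `test`. `copy.deepcopy` is used so that nested values such as the pipeline list are not shared between entries; the dict comprehension in the PR makes a shallow copy, which is sufficient there because only top-level keys are overridden.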

49 changes: 37 additions & 12 deletions configs/textrecog/tps/README.md
@@ -1,9 +1,20 @@
# Thin-Plate-Spline (TPS) transformation
# CRNN with TPS based STN

## Introduction

[ALGORITHM]

```bibtex
@article{shi2016end,
title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition},
author={Shi, Baoguang and Bai, Xiang and Yao, Cong},
journal={IEEE transactions on pattern analysis and machine intelligence},
year={2016}
}
```

[PREPROCESSOR]

```bibtex
@article{shi2016robust,
title={Robust Scene Text Recognition with Automatic Rectification},
@@ -13,14 +24,28 @@
}
```

## About using TPS in other models

- Simply change `cfg.model.preprocessor` from `None` to
```python
dict(
type='TPSPreprocessor',
num_fiducial=20,
img_size=(32, 100),
rectified_img_size=(32, 100),
num_img_channel=1
)
## Results and Models

### Train Dataset

| trainset | instance_num | repeat_num | note |
| :------: | :----------: | :--------: | :---: |
| Syn90k | 8919273 | 1 | synth |

### Test Dataset

| testset | instance_num | note |
| :-----: | :----------: | :-----: |
| IIIT5K | 3000 | regular |
| SVT | 647 | regular |
| IC13 | 1015 | regular |
| IC15 | 2077 |irregular|
| SVTP | 645 |irregular|
| CT80 | 288 |irregular|

## Results and models

| methods | | Regular Text | | | | Irregular Text | | download |
| :------------------------------------------------------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
| [CRNN-STN](/configs/textrecog/tps/crnn_tps_academic_dataset.py) | 80.8 | 81.3 | 85.0 | | 59.6 | 68.1 | 53.8 | [model](https://download.openmmlab.com/mmocr/textrecog/tps/crnn_tps_academic_dataset_20210510-d221a905.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/tps/20210510_204353.log.json) |
47 changes: 35 additions & 12 deletions configs/textrecog/tps/crnn_tps_academic_dataset.py
@@ -26,7 +26,7 @@
img_size=(32, 100),
rectified_img_size=(32, 100),
num_img_channel=1),
backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1),
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
Contributor (review comment):
So here is a missing item in the previous refactor. Are there any more items missed?

Contributor (review comment):
This can be done in a separate pr

encoder=None,
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss'),
@@ -68,9 +68,9 @@
dict(
type='ResizeOCR',
height=32,
min_width=4,
max_width=None,
keep_aspect_ratio=True),
min_width=32,
max_width=100,
keep_aspect_ratio=False),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
@@ -100,13 +100,20 @@
test_mode=False)

test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'icdar_2013/'
test_img_prefix2 = test_prefix + 'IIIT5K/'
test_img_prefix3 = test_prefix + 'svt/'

test_ann_file1 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file2 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file3 = test_prefix + 'svt/test_label.txt'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'

test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'

test1 = dict(
type=dataset_type,
@@ -131,12 +138,28 @@
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3

test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4

test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5

test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6

data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
train=dict(type='ConcatDataset', datasets=[train1]),
val=dict(type='ConcatDataset', datasets=[test1, test2, test3]),
test=dict(type='ConcatDataset', datasets=[test1, test2, test3]))
val=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]),
test=dict(
type='ConcatDataset',
datasets=[test1, test2, test3, test4, test5, test6]))

evaluation = dict(interval=1, metric='acc')

63 changes: 40 additions & 23 deletions mmocr/models/textrecog/preprocessor/tps_preprocessor.py
@@ -23,26 +23,35 @@

@PREPROCESSOR.register_module()
class TPSPreprocessor(BasePreprocessor):
"""Rectification Network of RARE, namely TPS based STN."""
"""Rectification Network of RARE, namely the TPS-based STN described in
`Robust Scene Text Recognition with Automatic Rectification
<https://arxiv.org/pdf/1603.03915.pdf>`_.

Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
img_size (tuple(int, int)): Size (height, width) of the input image.
rectified_img_size (tuple(int, int)):
Size (height, width) of the rectified image.
num_img_channel (int): Number of channels of the input image.

Output:
batch_rectified_img: Rectified image with size
[batch_size x num_img_channel x rectified_img_height
x rectified_img_width]
"""

def __init__(self,
num_fiducial,
img_size,
rectified_img_size,
num_fiducial=20,
img_size=(32, 100),
rectified_img_size=(32, 100),
num_img_channel=1):
""" Based on RARE TPS
Args:
num_fiducial (int): number of fiducial points of TPS-STN
img_size (int, int): (height, width) of the input image
rectified_img_size (int, int):
(height, width) of the rectified image
num_img_channel (int): the number of channels of the input image
output:
batch_rectified_img: rectified image
[batch_size x num_img_channel x rectified_img_height
x rectified_img_width]
"""
super().__init__()
assert isinstance(num_fiducial, int)
assert num_fiducial > 0
assert isinstance(img_size, tuple)
assert isinstance(rectified_img_size, tuple)
assert isinstance(num_img_channel, int)

self.num_fiducial = num_fiducial
self.img_size = img_size
self.rectified_img_size = rectified_img_size
@@ -71,13 +80,15 @@ def forward(self, batch_img):

return batch_rectified_img

def init_weights(self):
pass


class LocalizationNetwork(nn.Module):
"""Localization Network of RARE, which predicts C' (K x 2) from input
(img_width x img_height)"""
(img_width x img_height)

Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
num_img_channel (int): Number of channels of the input image.
"""

def __init__(self, num_fiducial, num_img_channel):
super().__init__()
@@ -128,7 +139,8 @@ def forward(self, batch_img):
Args:
batch_img (tensor): Batch Input Image
[batch_size x num_img_channel x img_height x img_width]
output:

Output:
batch_C_prime : Predicted coordinates of fiducial points for
input batch [batch_size x num_fiducial x 2]
"""
@@ -141,8 +153,13 @@ def forward(self, batch_img):


class GridGenerator(nn.Module):
"""Grid Generator of RARE, which produces P_prime by multiplying T with
P."""
"""Grid Generator of RARE, which produces P_prime by multiplying T with P.

Args:
num_fiducial (int): Number of fiducial points of TPS-STN.
rectified_img_size (tuple(int, int)):
Size (height, width) of the rectified image.
"""

def __init__(self, num_fiducial, rectified_img_size):
"""Generate P_hat and inv_delta_C for later."""
10 changes: 10 additions & 0 deletions tests/test_models/test_ocr_preprocessor.py
@@ -1,10 +1,20 @@
import pytest
import torch

from mmocr.models.textrecog.preprocessor import (BasePreprocessor,
TPSPreprocessor)


def test_tps_preprocessor():
with pytest.raises(AssertionError):
TPSPreprocessor(num_fiducial=-1)
with pytest.raises(AssertionError):
TPSPreprocessor(img_size=32)
with pytest.raises(AssertionError):
TPSPreprocessor(rectified_img_size=100)
with pytest.raises(AssertionError):
TPSPreprocessor(num_img_channel='bgr')

tps_preprocessor = TPSPreprocessor(
num_fiducial=20,
img_size=(32, 100),