Skip to content

Commit

Permalink
[Draft] Rec TTA
Browse files Browse the repository at this point in the history
  • Loading branch information
liukuikun committed Sep 22, 2022
1 parent 93d883e commit 0a19442
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 7 deletions.
2 changes: 2 additions & 0 deletions configs/textrecog/_base_/default_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@
type='TextRecogLocalVisualizer',
name='visualizer',
vis_backends=vis_backends)

# Wrapper model config for test-time augmentation. When ``tools/test.py`` is
# run with ``--tta`` it nests the original ``cfg.model`` under ``module`` of
# this dict and uses the result as ``cfg.model``.
tta_model = dict(type='EncoderDecoderRecognizerTTAModel')
57 changes: 57 additions & 0 deletions configs/textrecog/crnn/_base_crnn_mini-vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,60 @@
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]

# Test-time-augmentation pipeline. ``tools/test.py --tta`` assigns it to
# ``cfg.test_dataloader.dataset.pipeline`` in place of the regular pipeline.
tta_pipeline = [
    dict(
        type='LoadImageFromFile',
        color_type='grayscale',
        file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        # NOTE(review): each inner list below is presumably one stage whose
        # entries are alternative transforms, with one augmented view built
        # per combination -- confirm against mmcv's ``TestTimeAug``.
        transforms=[
            [
                # Three rotation variants (imgaug ``Rot90`` with k = 0, 1, 3).
                # Each one is only applied when the condition holds, i.e. when
                # ``img_shape[1] < img_shape[0]`` (presumably width < height,
                # a vertical text image -- TODO confirm the (h, w) order).
                # Otherwise ``ConditionApply`` returns the sample unchanged.
                dict(
                    type='ConditionApply',
                    transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [
                # Single-entry stage: the same rescale setting is used for
                # every rotation variant.
                dict(
                    type='RescaleToHeight',
                    height=32,
                    min_width=32,
                    max_width=None,
                    width_divisor=16)
            ],
            # Annotations are loaded after the image has been rescaled because
            # the ground-truth text does not need any resize transform.
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]
55 changes: 55 additions & 0 deletions configs/textrecog/robust_scanner/_base_robustscanner_resnet31.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,58 @@
type='PackTextRecogInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
]

# Test-time-augmentation pipeline. ``tools/test.py --tta`` assigns it to
# ``cfg.test_dataloader.dataset.pipeline`` in place of the regular pipeline.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='TestTimeAug',
        # NOTE(review): each inner list below is presumably one stage whose
        # entries are alternative transforms, with one augmented view built
        # per combination -- confirm against mmcv's ``TestTimeAug``.
        transforms=[
            [
                # Three rotation variants (imgaug ``Rot90`` with k = 0, 1, 3).
                # Each one is only applied when the condition holds, i.e. when
                # ``img_shape[1] < img_shape[0]`` (presumably width < height,
                # a vertical text image -- TODO confirm the (h, w) order).
                # Otherwise ``ConditionApply`` returns the sample unchanged.
                dict(
                    type='ConditionApply',
                    transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=0, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=1, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
                dict(
                    type='ConditionApply',
                    transforms=[
                        dict(
                            type='ImgAugWrapper',
                            args=[dict(cls='Rot90', k=3, keep_size=False)])
                    ],
                    condition="results['img_shape'][1]<results['img_shape'][0]"
                ),
            ],
            [
                # Single-entry stage: the same rescale setting is used for
                # every rotation variant.
                dict(
                    type='RescaleToHeight',
                    height=48,
                    min_width=48,
                    max_width=160,
                    width_divisor=4),
            ],
            # The pad width equals ``max_width`` of the rescale stage above.
            [dict(type='PadToWidth', width=160)],
            # Annotations are loaded after the image has been rescaled because
            # the ground-truth text does not need any resize transform.
            [dict(type='LoadOCRAnnotations', with_text=True)],
            [
                dict(
                    type='PackTextRecogInputs',
                    meta_keys=('img_path', 'ori_shape', 'img_shape',
                               'valid_ratio'))
            ]
        ])
]
4 changes: 2 additions & 2 deletions mmocr/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
SourceImagePad, TextDetRandomCrop,
TextDetRandomCropFlip)
from .textrecog_transforms import PadToWidth, PyramidRescale, RescaleToHeight
from .wrappers import ImgAugWrapper, TorchVisionWrapper
from .wrappers import ConditionApply, ImgAugWrapper, TorchVisionWrapper

__all__ = [
'LoadOCRAnnotations', 'RandomRotate', 'ImgAugWrapper', 'SourceImagePad',
Expand All @@ -20,5 +20,5 @@
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
'LoadImageFromNDArray'
'LoadImageFromNDArray', 'ConditionApply'
]
16 changes: 16 additions & 0 deletions mmocr/datasets/transforms/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import imgaug.augmenters as iaa
import numpy as np
import torchvision.transforms as torchvision_transforms
from mmcv.transforms import Compose
from mmcv.transforms.base import BaseTransform
from PIL import Image

Expand Down Expand Up @@ -277,3 +278,18 @@ def __repr__(self):
repr_str += f', {k} = {v}'
repr_str += ')'
return repr_str


@TRANSFORMS.register_module()
class ConditionApply(BaseTransform):
    """Apply a group of transforms only when a condition on the sample holds.

    Args:
        transforms (list[dict]): Configs of the transforms to run when the
            condition is met. They are wrapped into a single ``Compose``.
        condition (str): A Python expression that is evaluated for every
            sample and interpreted as a boolean. The sample dict is
            available inside the expression under the name ``results``,
            e.g. ``"results['img_shape'][1] < results['img_shape'][0]"``.

    Warning:
        ``condition`` is executed with :func:`eval`. It must only come from
        trusted configuration files, never from untrusted input.
    """

    def __init__(self, transforms, condition: str) -> None:
        super().__init__()
        self.condition = condition
        self.transforms = Compose(transforms)

    def transform(self, results: Dict) -> Optional[Dict]:
        """Run the wrapped transforms if the condition is met.

        Args:
            results (dict): The sample dict handed over by the pipeline.

        Returns:
            dict or None: The output of the wrapped transforms when the
            condition evaluates to a truthy value, otherwise ``results``
            itself, untouched.
        """
        # SECURITY: ``eval`` runs arbitrary code. The expression is evaluated
        # in this method's scope on purpose so that it can refer to
        # ``results``; keep the condition string under the config author's
        # control.
        if eval(self.condition):
            return self.transforms(results)  # type: ignore
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(condition = {self.condition!r}'
        repr_str += f', transforms = {self.transforms}'
        repr_str += ')'
        return repr_str
4 changes: 3 additions & 1 deletion mmocr/models/textrecog/recognizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base import BaseRecognizer
from .crnn import CRNN
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
from .encoder_decoder_recognizer_tta import EncoderDecoderRecognizerTTAModel
from .master import MASTER
from .nrtr import NRTR
from .robust_scanner import RobustScanner
Expand All @@ -11,5 +12,6 @@

__all__ = [
'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
'RobustScanner', 'SATRN', 'ABINet', 'MASTER'
'RobustScanner', 'SATRN', 'ABINet', 'MASTER',
'EncoderDecoderRecognizerTTAModel'
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine.model import BaseTTAModel

from mmocr.registry import MODELS
from mmocr.utils.typing import RecSampleList


@MODELS.register_module()
class EncoderDecoderRecognizerTTAModel(BaseTTAModel):
    """Test-time-augmentation wrapper for text recognizers.

    For every sample, the prediction with the highest mean character score
    among all augmented views is kept.
    """

    def merge_preds(self,
                    data_samples_list: List[RecSampleList]) -> RecSampleList:
        """Merge predictions of enhanced data to one prediction.

        Entry ``i`` of every inner list is treated as a prediction for the
        same sample ``i``; the inner lists are the competing augmented
        views.
        NOTE(review): confirm this layout (outer = views, inner = samples)
        against the ``BaseTTAModel.test_step`` of the mmengine version in
        use.

        Args:
            data_samples_list (List[RecSampleList]): List of predictions of
                all enhanced data.

        Returns:
            RecSampleList: Merged prediction, one data sample per position
            of the inner lists. Empty if ``data_samples_list`` is empty.
        """
        # Nothing to merge; avoids an IndexError on ``data_samples_list[0]``.
        if not data_samples_list:
            return []

        num_samples = len(data_samples_list[0])
        predictions = [None] * num_samples
        # Start from -inf instead of -1 so that a candidate is always picked,
        # even if a decoder reports scores whose mean is <= -1
        # (e.g. log-probabilities).
        best_scores = [float('-inf')] * num_samples
        for data_samples in data_samples_list:
            for i, data_sample in enumerate(data_samples):
                score = data_sample.pred_text.score
                # Mean over the per-character scores; an empty prediction
                # gets a mean of 0 rather than a ZeroDivisionError.
                average_score = sum(score) / max(1, len(score))
                if average_score > best_scores[i]:
                    predictions[i] = data_sample
                    best_scores[i] = average_score
        return predictions
5 changes: 1 addition & 4 deletions tools/analysis_tools/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ def parse_args():
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args([
'configs/textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py',
'--output-dir', 'tools/analysis_tools/save', '--not-show'
])
args = parser.parse_args()
return args


Expand Down
9 changes: 9 additions & 0 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def parse_args():
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='Job launcher')
parser.add_argument(
'--tta', action='store_true', help='Test time augmentation')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
Expand Down Expand Up @@ -107,6 +109,13 @@ def main():
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)

cfg.load_from = args.checkpoint

if args.tta:
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
cfg.tta_model.module = cfg.model
cfg.model = cfg.tta_model

# save predictions
if args.save_preds:
dump_metric = dict(
Expand Down

0 comments on commit 0a19442

Please sign in to comment.