Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support TTA #9452

Merged
merged 22 commits into from
Feb 14, 2023
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py',
'./centernet_tta.py'
]

dataset_type = 'CocoDataset'
Expand Down
39 changes: 39 additions & 0 deletions configs/centernet/centernet_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Test-time augmentation (TTA) config for CenterNet.

# TTA wrapper model; ``tta_cfg`` carries the NMS settings and the per-image
# detection cap handed to ``DetTTAModel``.
tta_model = dict(
    type='DetTTAModel',
    tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.5), max_per_img=100))

# Each inner list of ``transforms`` is one augmentation stage listing the
# alternative transforms for that stage.
tta_pipeline = [
    dict(
        type='LoadImageFromFile',
        to_float32=True,
        file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                # ``RandomFlip`` must be placed before ``RandomCenterCropPad``,
                # otherwise bounding box coordinates after flipping cannot be
                # recovered correctly.
                dict(type='RandomFlip', prob=1.),
                dict(type='RandomFlip', prob=0.)
            ],
            [
                dict(
                    type='RandomCenterCropPad',
                    ratios=None,
                    border=None,
                    mean=[0, 0, 0],
                    std=[1, 1, 1],
                    to_rgb=True,
                    test_mode=True,
                    test_pad_mode=['logical_or', 31],
                    test_pad_add_pix=1),
            ],
            [
                # ``border`` is kept in the meta info in addition to the flip
                # keys.
                dict(
                    type='PackDetInputs',
                    meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                               'flip', 'flip_direction', 'border'))
            ]
        ])
]
3 changes: 2 additions & 1 deletion configs/retinanet/retinanet_r50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py',
'./retinanet_tta.py'
]

# optimizer
Expand Down
23 changes: 23 additions & 0 deletions configs/retinanet/retinanet_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Test-time augmentation (TTA) config for RetinaNet.

# TTA wrapper model; ``tta_cfg`` carries the NMS settings and the per-image
# detection cap handed to ``DetTTAModel``.
tta_model = dict(
    type='DetTTAModel',
    tta_cfg=dict(
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100,
    ),
)

# Scales used by the multi-scale ``Resize`` stage below.
img_scales = [(1333, 800), (666, 400), (2000, 1200)]

# Each inner list of ``transforms`` is one augmentation stage listing the
# alternative transforms for that stage.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            # Stage 1: one ``Resize`` per entry of ``img_scales``.
            [
                dict(type='Resize', scale=scale, keep_ratio=True)
                for scale in img_scales
            ],
            # Stage 2: flipped and non-flipped variants.
            [
                dict(type='RandomFlip', prob=1.0),
                dict(type='RandomFlip', prob=0.0),
            ],
            # Stage 3: pack the inputs together with the listed meta keys.
            [
                dict(
                    type='PackDetInputs',
                    meta_keys=(
                        'img_id', 'img_path', 'ori_shape', 'img_shape',
                        'scale_factor', 'flip', 'flip_direction',
                    ),
                ),
            ],
        ],
    ),
]
2 changes: 1 addition & 1 deletion configs/rtmdet/rtmdet_l_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_1x.py',
'../_base_/datasets/coco_detection.py'
'../_base_/datasets/coco_detection.py', './rtmdet_tta.py'
]
model = dict(
type='RTMDet',
Expand Down
35 changes: 35 additions & 0 deletions configs/rtmdet/rtmdet_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Test-time augmentation (TTA) config for RTMDet.

# TTA wrapper model; ``tta_cfg`` carries the NMS settings and the per-image
# detection cap handed to ``DetTTAModel``.
tta_model = dict(
    type='DetTTAModel',
    tta_cfg=dict(
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100,
    ),
)

# Scales used by the multi-scale ``Resize`` stage below.
img_scales = [(640, 640), (320, 320), (960, 960)]

# Each inner list of ``transforms`` is one augmentation stage listing the
# alternative transforms for that stage.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            # Stage 1: one ``Resize`` per entry of ``img_scales``.
            [
                dict(type='Resize', scale=scale, keep_ratio=True)
                for scale in img_scales
            ],
            # Stage 2: flipped and non-flipped variants.
            # ``RandomFlip`` must be placed before ``Pad``, otherwise
            # bounding box coordinates after flipping cannot be
            # recovered correctly.
            [
                dict(type='RandomFlip', prob=1.0),
                dict(type='RandomFlip', prob=0.0),
            ],
            # Stage 3: pad to a fixed 960x960 canvas with gray (114) pixels.
            [
                dict(
                    type='Pad',
                    size=(960, 960),
                    pad_val=dict(img=(114, 114, 114)),
                ),
            ],
            # Stage 4: pack the inputs together with the listed meta keys.
            [
                dict(
                    type='PackDetInputs',
                    meta_keys=(
                        'img_id', 'img_path', 'ori_shape', 'img_shape',
                        'scale_factor', 'flip', 'flip_direction',
                    ),
                ),
            ],
        ],
    ),
]
5 changes: 4 additions & 1 deletion configs/yolox/yolox_s_8xb8-300e_coco.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
_base_ = ['../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py']
_base_ = [
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py',
'./yolox_tta.py'
]

img_scale = (640, 640) # width, height

Expand Down
35 changes: 35 additions & 0 deletions configs/yolox/yolox_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Test-time augmentation (TTA) config for YOLOX.

# TTA wrapper model; ``tta_cfg`` carries the NMS settings and the per-image
# detection cap handed to ``DetTTAModel``.
tta_model = dict(
    type='DetTTAModel',
    tta_cfg=dict(
        nms=dict(type='nms', iou_threshold=0.65),
        max_per_img=100,
    ),
)

# Scales used by the multi-scale ``Resize`` stage below.
img_scales = [(640, 640), (320, 320), (960, 960)]

# Each inner list of ``transforms`` is one augmentation stage listing the
# alternative transforms for that stage.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            # Stage 1: one ``Resize`` per entry of ``img_scales``.
            [
                dict(type='Resize', scale=scale, keep_ratio=True)
                for scale in img_scales
            ],
            # Stage 2: flipped and non-flipped variants.
            # ``RandomFlip`` must be placed before ``Pad``, otherwise
            # bounding box coordinates after flipping cannot be
            # recovered correctly.
            [
                dict(type='RandomFlip', prob=1.0),
                dict(type='RandomFlip', prob=0.0),
            ],
            # Stage 3: pad to a square with gray (114) pixels.
            [
                dict(
                    type='Pad',
                    pad_to_square=True,
                    pad_val=dict(img=(114.0, 114.0, 114.0)),
                ),
            ],
            # Stage 4: pack the inputs together with the listed meta keys.
            [
                dict(
                    type='PackDetInputs',
                    meta_keys=(
                        'img_id', 'img_path', 'ori_shape', 'img_shape',
                        'scale_factor', 'flip', 'flip_direction',
                    ),
                ),
            ],
        ],
    ),
]
73 changes: 73 additions & 0 deletions docs/en/user_guides/test.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,76 @@ data = dict(train_dataloader=dict(...), val_dataloader=dict(...), test_dataloade
```

Or you can set it through `--cfg-options` as `--cfg-options test_dataloader.batch_size=2`

## Test Time Augmentation (TTA)

Test time augmentation (TTA) is a data augmentation strategy used during the testing phase. It involves applying various augmentations, such as flipping and scaling, to the same image and then merging the predictions of each augmented image to produce a more accurate prediction. To make it easier for users to use TTA, MMEngine provides the [BaseTTAModel](https://mmengine.readthedocs.io/en/latest/api/generated/mmengine.model.BaseTTAModel.html#mmengine.model.BaseTTAModel) class, which allows users to implement different TTA strategies by simply extending the BaseTTAModel class according to their needs.
RangiLyu marked this conversation as resolved.
Show resolved Hide resolved

In MMDetection, we provide the [DetTTAModel](../../../mmdet/models/test_time_augs/det_tta.py) class, which inherits from BaseTTAModel.

You can simply run:

```shell
# Single-gpu testing
python tools/test.py \
${CONFIG_FILE} \
${CHECKPOINT_FILE} \
[--tta]

# CPU: disable GPUs and run single-gpu testing script
export CUDA_VISIBLE_DEVICES=-1
python tools/test.py \
${CONFIG_FILE} \
${CHECKPOINT_FILE} \
[--out ${RESULT_FILE}] \
[--tta]

# Multi-gpu testing
bash tools/dist_test.sh \
${CONFIG_FILE} \
${CHECKPOINT_FILE} \
${GPU_NUM} \
[--tta]
```

By default, we only use 2 flip augmentations (flipping and not flipping).
You can also modify the TTA config yourself, for example by adding multi-scale augmentation:

```python
tta_model = dict(
type='DetTTAModel',
tta_cfg=dict(nms=dict(
type='nms',
iou_threshold=0.5),
max_per_img=100))

img_scales = [(1333, 800), (666, 400), (2000, 1200)]
tta_pipeline = [
dict(type='LoadImageFromFile',
file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[[
dict(type='Resize', scale=s, keep_ratio=True) for s in img_scales
], [
dict(type='RandomFlip', prob=1.),
dict(type='RandomFlip', prob=0.)
], [
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor', 'flip',
'flip_direction'))
]])]
```

The above data augmentation pipeline first resizes the image to each of the 3 scales and then applies 2 flip augmentations (flipping and not flipping) to each resized image, giving 6 augmented views in total. Finally, each augmented image is packed into the final result using PackDetInputs.
zytx121 marked this conversation as resolved.
Show resolved Hide resolved

Here are some TTA configs for your reference:

- [RetinaNet](../../../configs/retinanet/retinanet_tta.py)
- [CenterNet](../../../configs/centernet/centernet_tta.py)
- [YOLOX](../../../configs/yolox/yolox_tta.py)
- [RTMDet](../../../configs/rtmdet/rtmdet_tta.py)

For more advanced usage and data flow of TTA, please refer to [MMEngine](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/test_time_augmentation.html#data-flow). We will support instance segmentation TTA later.
6 changes: 3 additions & 3 deletions mmdet/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mmengine.logging import MessageHub
from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
from mmengine.structures import PixelData
from mmengine.utils import is_list_of
from mmengine.utils import is_seq_of
from torch import Tensor

from mmdet.models.utils import unfold_wo_center
Expand Down Expand Up @@ -149,7 +149,7 @@ def _get_pad_shape(self, data: dict) -> List[tuple]:
pad_size_divisor."""
_batch_inputs = data['inputs']
# Process data with `pseudo_collate`.
if is_list_of(_batch_inputs, torch.Tensor):
if is_seq_of(_batch_inputs, torch.Tensor):
batch_pad_shape = []
for ori_input in _batch_inputs:
pad_h = int(
Expand All @@ -173,7 +173,7 @@ def _get_pad_shape(self, data: dict) -> List[tuple]:
self.pad_size_divisor)) * self.pad_size_divisor
batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
else:
raise TypeError('Output of `cast_data` should be a list of dict '
raise TypeError('Output of `cast_data` should be a dict '
'or a tuple with inputs and data_samples, but got'
f'{type(data)}: {data}')
return batch_pad_shape
Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/test_time_augs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .det_tta import DetTTAModel
from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
merge_aug_proposals, merge_aug_results,
merge_aug_scores)

__all__ = [
'merge_aug_bboxes', 'merge_aug_masks', 'merge_aug_proposals',
'merge_aug_scores', 'merge_aug_results'
'merge_aug_scores', 'merge_aug_results', 'DetTTAModel'
]
Loading