Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support TTA and add --tta in tools/test.py. #1161

Merged
merged 3 commits on Dec 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mmcls/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .retrievers import * # noqa: F401,F403
from .tta import * # noqa: F401,F403
from .utils import * # noqa: F401,F403

__all__ = [
Expand Down
4 changes: 4 additions & 0 deletions mmcls/models/tta/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Test-time augmentation (TTA) models for classification."""
from .score_tta import AverageClsScoreTTA

# Public API of the TTA sub-package.
__all__ = ['AverageClsScoreTTA']
36 changes: 36 additions & 0 deletions mmcls/models/tta/score_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine.model import BaseTTAModel

from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample


@MODELS.register_module()
class AverageClsScoreTTA(BaseTTAModel):
    """TTA model that fuses augmented views by averaging their scores.

    The wrapped module produces one prediction per augmented view of a
    sample; this class merges those predictions into a single
    ``ClsDataSample`` whose classification score is the arithmetic mean
    of the per-view scores.
    """

    def merge_preds(
        self,
        data_samples_list: List[List[ClsDataSample]],
    ) -> List[ClsDataSample]:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (List[List[ClsDataSample]]): List of predictions
                of all enhanced data.

        Returns:
            List[ClsDataSample]: Merged prediction.
        """
        return [
            self._merge_single_sample(per_view_samples)
            for per_view_samples in data_samples_list
        ]

    def _merge_single_sample(self, data_samples):
        """Average the predicted scores of one sample's augmented views."""
        # Reuse the metainfo of the first view for the fused sample.
        fused: ClsDataSample = data_samples[0].new()
        scores = [sample.pred_label.score for sample in data_samples]
        fused.set_pred_score(sum(scores) / len(scores))
        return fused
67 changes: 67 additions & 0 deletions tests/test_models/test_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase

import torch
from mmengine import ConfigDict

from mmcls.models import AverageClsScoreTTA, ImageClassifier
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from mmcls.utils import register_all_modules

register_all_modules()


class TestAverageClsScoreTTA(TestCase):
    """Unit tests for the ``AverageClsScoreTTA`` test-time-aug wrapper."""

    DEFAULT_ARGS = dict(
        type='AverageClsScoreTTA',
        module=dict(
            type='ImageClassifier',
            backbone=dict(type='ResNet', depth=18),
            neck=dict(type='GlobalAveragePooling'),
            head=dict(
                type='LinearClsHead',
                num_classes=10,
                in_channels=512,
                loss=dict(type='CrossEntropyLoss'))))

    def test_initialize(self):
        """Building the TTA config should wrap an ``ImageClassifier``."""
        model: AverageClsScoreTTA = MODELS.build(self.DEFAULT_ARGS)
        self.assertIsInstance(model.module, ImageClassifier)

    def test_forward(self):
        """Calling the wrapper directly must raise NotImplementedError."""
        inputs = torch.rand(1, 3, 224, 224)
        model: AverageClsScoreTTA = MODELS.build(self.DEFAULT_ARGS)

        # The forward of TTA model should not be called.
        with self.assertRaisesRegex(NotImplementedError, 'will not be called'):
            model(inputs)

    def test_test_step(self):
        """The TTA score must equal the mean of the per-view scores."""
        cfg = ConfigDict(deepcopy(self.DEFAULT_ARGS))
        cfg.module.data_preprocessor = dict(
            mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5])
        model: AverageClsScoreTTA = MODELS.build(cfg)

        img1 = torch.randint(0, 256, (1, 3, 224, 224))
        img2 = torch.randint(0, 256, (1, 3, 224, 224))
        data1 = {
            'inputs': img1,
            'data_samples': [ClsDataSample().set_gt_label(1)]
        }
        data2 = {
            'inputs': img2,
            'data_samples': [ClsDataSample().set_gt_label(1)]
        }
        data_tta = {
            'inputs': [img1, img2],
            'data_samples': [[ClsDataSample().set_gt_label(1)],
                             [ClsDataSample().set_gt_label(1)]]
        }

        score1 = model.module.test_step(data1)[0].pred_label.score
        score2 = model.module.test_step(data2)[0].pred_label.score
        score_tta = model.test_step(data_tta)[0].pred_label.score

        # `torch.testing.assert_allclose` is deprecated (removed in recent
        # PyTorch releases); `assert_close` is the supported replacement.
        torch.testing.assert_close(score_tta, (score1 + score2) / 2)
29 changes: 28 additions & 1 deletion tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def parse_args():
'--no-pin-memory',
action='store_true',
help='whether to disable the pin_memory option in dataloaders.')
parser.add_argument(
'--tta',
action='store_true',
help='Whether to enable the Test-Time-Aug (TTA). If the config file '
'has `tta_pipeline` and `tta_model` fields, use them to determine the '
'TTA transforms and how to merge the TTA results. Otherwise, use flip '
'TTA by averaging classification score.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
Expand Down Expand Up @@ -105,7 +112,27 @@ def merge_args(cfg, args):
else:
cfg.test_evaluator = [cfg.test_evaluator, dump_metric]

# set dataloader args
# -------------------- TTA related args --------------------
if args.tta:
if 'tta_model' not in cfg:
cfg.tta_model = dict(type='mmcls.AverageClsScoreTTA')
if 'tta_pipeline' not in cfg:
test_pipeline = cfg.test_dataloader.dataset.pipeline
cfg.tta_pipeline = deepcopy(test_pipeline)
flip_tta = dict(
type='TestTimeAug',
transforms=[
[
dict(type='RandomFlip', prob=1.),
dict(type='RandomFlip', prob=0.)
],
[test_pipeline[-1]],
])
cfg.tta_pipeline[-1] = flip_tta
cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline

# ----------------- Default dataloader args -----------------
default_dataloader_cfg = ConfigDict(
pin_memory=True,
collate_fn=dict(type='default_collate'),
Expand Down