[MMSIG] Support badcase analyze in test #2584
Merged
Changes from 10 commits
21 commits
cc5f248 add badcase hook (Indigo6)
cc9e29e add loss based badcase analyze (Indigo6)
c621837 support accurayc based badcase analyze (Indigo6)
e1eadcf fix configdict bug (Indigo6)
3496be9 revert cfg (Indigo6)
bfd1707 add badcase analyze sample cfg (Indigo6)
c658f87 support draw_line with str value color (Indigo6)
5773b69 add unit test for badcase hook (Indigo6)
8e29729 use str based color (Indigo6)
4cfc7ab rm useless codes and add warnings (Indigo6)
54eecc3 move badcase hook config to default_runtime.py (Indigo6)
271d105 rename badcase hook and fix linting (Indigo6)
ac462c7 set and sort default cfg of badcase (Indigo6)
7595b4f update badcase or pred show logic (Indigo6)
2d68e40 fix linting in test_badcase_hook.py (Indigo6)
371a01f fix rename bug (Indigo6)
34aca1f fix rename bug (Indigo6)
af3f475 Update mmpose/engine/hooks/badcase_hook.py (Indigo6)
8607f12 Update mmpose/engine/hooks/badcase_hook.py (Indigo6)
6149d5a fix linting (Indigo6)
725c1a1 bgr2rgb after mmcv.color_val (Indigo6)
124 changes: 124 additions & 0 deletions
...ody_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
_base_ = ['../../../_base_/default_runtime.py'] | ||
|
||
# runtime | ||
train_cfg = dict(max_epochs=210, val_interval=10) | ||
|
||
# optimizer | ||
optim_wrapper = dict(optimizer=dict( | ||
type='Adam', | ||
lr=5e-4, | ||
)) | ||
|
||
# learning policy | ||
param_scheduler = [ | ||
dict( | ||
type='LinearLR', begin=0, end=500, start_factor=0.001, | ||
by_epoch=False), # warm-up | ||
dict( | ||
type='MultiStepLR', | ||
begin=0, | ||
end=210, | ||
milestones=[170, 200], | ||
gamma=0.1, | ||
by_epoch=True) | ||
] | ||
|
||
# automatically scaling LR based on the actual training batch size | ||
auto_scale_lr = dict(base_batch_size=512) | ||
|
||
# hooks | ||
default_hooks = dict(checkpoint=dict(save_best='PCK', rule='greater'), | ||
badcase=dict(type="BadCaseAnalyzeHook", | ||
# metric_type="loss", | ||
metric_type="accuracy", | ||
show=True, | ||
badcase_thr=100, | ||
out_dir='badcase')) | ||
|
||
# codec settings | ||
codec = dict( | ||
type='MSRAHeatmap', input_size=(256, 256), heatmap_size=(64, 64), sigma=2) | ||
|
||
# model settings | ||
model = dict( | ||
type='TopdownPoseEstimator', | ||
data_preprocessor=dict( | ||
type='PoseDataPreprocessor', | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
bgr_to_rgb=True), | ||
backbone=dict( | ||
type='MobileNetV2', | ||
widen_factor=1., | ||
out_indices=(7, ), | ||
init_cfg=dict(type='Pretrained', checkpoint='mmcls://mobilenet_v2'), | ||
), | ||
head=dict( | ||
type='HeatmapHead', | ||
in_channels=1280, | ||
out_channels=16, | ||
loss=dict(type='KeypointMSELoss', use_target_weight=True), | ||
decoder=codec), | ||
test_cfg=dict( | ||
flip_test=True, | ||
flip_mode='heatmap', | ||
shift_heatmap=True, | ||
)) | ||
|
||
# base dataset settings | ||
dataset_type = 'MpiiDataset' | ||
data_mode = 'topdown' | ||
data_root = 'data/mpii/' | ||
|
||
# pipelines | ||
train_pipeline = [ | ||
dict(type='LoadImage'), | ||
dict(type='GetBBoxCenterScale'), | ||
dict(type='RandomFlip', direction='horizontal'), | ||
dict(type='RandomBBoxTransform', shift_prob=0), | ||
dict(type='TopdownAffine', input_size=codec['input_size']), | ||
dict(type='GenerateTarget', encoder=codec), | ||
dict(type='PackPoseInputs') | ||
] | ||
val_pipeline = [ | ||
dict(type='LoadImage'), | ||
dict(type='GetBBoxCenterScale'), | ||
dict(type='TopdownAffine', input_size=codec['input_size']), | ||
dict(type='PackPoseInputs') | ||
] | ||
|
||
# data loaders | ||
train_dataloader = dict( | ||
batch_size=64, | ||
num_workers=2, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_mode=data_mode, | ||
ann_file='annotations/mpii_train.json', | ||
data_prefix=dict(img='images/'), | ||
pipeline=train_pipeline, | ||
)) | ||
val_dataloader = dict( | ||
batch_size=32, | ||
num_workers=2, | ||
persistent_workers=True, | ||
drop_last=False, | ||
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_mode=data_mode, | ||
ann_file='annotations/mpii_val.json', | ||
headbox_file='data/mpii/annotations/mpii_gt_val.mat', | ||
data_prefix=dict(img='images/'), | ||
test_mode=True, | ||
pipeline=val_pipeline, | ||
)) | ||
test_dataloader = val_dataloader | ||
|
||
# evaluators | ||
val_evaluator = dict(type='MpiiPCKAccuracy') | ||
test_evaluator = val_evaluator |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .ema_hook import ExpMomentumEMA | ||
from .visualization_hook import PoseVisualizationHook | ||
+from .badcase_hook import BadCaseAnalyzeHook | ||
|
||
-__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA'] | ||
+__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA', 'BadCaseAnalyzeHook'] |
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,222 @@ | ||||||||||||||
# Copyright (c) OpenMMLab. All rights reserved. | ||||||||||||||
import os | ||||||||||||||
import json | ||||||||||||||
import torch | ||||||||||||||
import warnings | ||||||||||||||
import numpy as np | ||||||||||||||
from typing import Optional, Sequence, Dict | ||||||||||||||
|
||||||||||||||
import mmcv | ||||||||||||||
import mmengine | ||||||||||||||
import mmengine.fileio as fileio | ||||||||||||||
from mmengine.config import ConfigDict | ||||||||||||||
from mmengine.hooks import Hook | ||||||||||||||
from mmengine.runner import Runner | ||||||||||||||
from mmengine.visualization import Visualizer | ||||||||||||||
|
||||||||||||||
from mmpose.registry import HOOKS, MODELS, METRICS | ||||||||||||||
from mmpose.structures import PoseDataSample, merge_data_samples | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
@HOOKS.register_module() | ||||||||||||||
class BadCaseAnalyzeHook(Hook): | ||||||||||||||
|
||||||||||||||
"""Bad Case Analyze Hook. Used to visualize validation and | ||||||||||||||
testing process prediction results. | ||||||||||||||
|
||||||||||||||
In the testing phase: | ||||||||||||||
|
||||||||||||||
1. If ``show`` is True, the prediction results are only visualized | ||||||||||||||
and not stored, so ``vis_backends`` is excluded. | ||||||||||||||
2. If ``out_dir`` is specified, the prediction results are saved to | ||||||||||||||
``out_dir``. To avoid ``vis_backends`` also storing data, | ||||||||||||||
``vis_backends`` is excluded. | ||||||||||||||
3. ``vis_backends`` takes effect if the user does not specify ``show`` | ||||||||||||||
and ``out_dir``. You can set ``vis_backends`` to WandbVisBackend or | ||||||||||||||
TensorboardVisBackend to store the prediction result in Wandb or | ||||||||||||||
Tensorboard. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
enable (bool): Whether to draw prediction results. If False, | ||||||||||||||
no drawing will be done. Defaults to False. | ||||||||||||||
show (bool): Whether to display the drawn image. Defaults to False. | ||||||||||||||
wait_time (float): The interval of show (s). Defaults to 0. | ||||||||||||||
interval (int): The interval of visualization. Defaults to 50. | ||||||||||||||
kpt_thr (float): The threshold to visualize the keypoints. Defaults to 0.3. | ||||||||||||||
out_dir (str, optional): Directory where painted images | ||||||||||||||
will be saved in the testing process. Defaults to None. | ||||||||||||||
backend_args (dict, optional): Arguments to instantiate the | ||||||||||||||
backend corresponding to the URI prefix. Defaults to None. | ||||||||||||||
metric_type (str): The metric type used to decide a badcase, | ||||||||||||||
either 'loss' or 'accuracy'. Defaults to 'loss'. | ||||||||||||||
metric (ConfigDict): The config of the metric. Defaults to | ||||||||||||||
``ConfigDict(type='KeypointMSELoss')``. | ||||||||||||||
metric_key (str): Key of the needed metric value in the dict | ||||||||||||||
returned by the metric. Defaults to 'PCK'. | ||||||||||||||
badcase_thr (float): Min loss or max accuracy for a badcase. | ||||||||||||||
Defaults to 5. | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
def __init__( | ||||||||||||||
self, | ||||||||||||||
enable: bool = False, | ||||||||||||||
show: bool = False, | ||||||||||||||
wait_time: float = 0., | ||||||||||||||
interval: int = 50, | ||||||||||||||
kpt_thr: float = 0.3, | ||||||||||||||
out_dir: Optional[str] = None, | ||||||||||||||
backend_args: Optional[dict] = None, | ||||||||||||||
metric_type: str = 'loss', | ||||||||||||||
metric: ConfigDict = ConfigDict(type='KeypointMSELoss'), | ||||||||||||||
metric_key: str = 'PCK', | ||||||||||||||
badcase_thr: float = 5, | ||||||||||||||
): | ||||||||||||||
self._visualizer: Visualizer = Visualizer.get_current_instance() | ||||||||||||||
self.interval = interval | ||||||||||||||
self.kpt_thr = kpt_thr | ||||||||||||||
self.show = show | ||||||||||||||
if self.show: | ||||||||||||||
# No need to think about vis backends. | ||||||||||||||
self._visualizer._vis_backends = {} | ||||||||||||||
warnings.warn('The show is True, it means that only ' | ||||||||||||||
'the prediction results are visualized ' | ||||||||||||||
'without storing data, so vis_backends ' | ||||||||||||||
'needs to be excluded.') | ||||||||||||||
|
||||||||||||||
self.wait_time = wait_time | ||||||||||||||
self.enable = enable | ||||||||||||||
self.out_dir = out_dir | ||||||||||||||
self._test_index = 0 | ||||||||||||||
self.backend_args = backend_args | ||||||||||||||
|
||||||||||||||
self.metric_type = metric_type | ||||||||||||||
if metric_type not in ['loss', 'accuracy']: | ||||||||||||||
raise KeyError( | ||||||||||||||
f'The badcase metric type {metric_type} is not supported by ' | ||||||||||||||
f"{self.__class__.__name__}. Should be one of 'loss', " | ||||||||||||||
f"'accuracy', but got {metric_type}.") | ||||||||||||||
self.metric = MODELS.build(metric) if metric_type == 'loss'\ | ||||||||||||||
else METRICS.build(metric) | ||||||||||||||
self.metric_name = metric.type if metric_type == 'loss'\ | ||||||||||||||
else metric_key | ||||||||||||||
self.metric_key = metric_key | ||||||||||||||
self.badcase_thr = badcase_thr | ||||||||||||||
self.results = [] | ||||||||||||||
|
||||||||||||||
def check_badcase(self, data_batch, data_sample): | ||||||||||||||
"""Check whether the sample is a badcase. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
data_batch (dict): The batch of data from the dataloader | ||||||||||||||
that contains this sample. | ||||||||||||||
data_sample (PoseDataSample): A single merged output sample | ||||||||||||||
from the model. | ||||||||||||||
Returns: | ||||||||||||||
is_badcase (bool): Whether the sample is a badcase. | ||||||||||||||
metric_value (float): The loss or accuracy value of the sample. | ||||||||||||||
""" | ||||||||||||||
if self.metric_type == 'loss': | ||||||||||||||
gts = data_sample.gt_instances.keypoints | ||||||||||||||
preds = data_sample.pred_instances.keypoints | ||||||||||||||
|
||||||||||||||
with torch.no_grad(): | ||||||||||||||
metric_value = self.metric(torch.tensor(preds), | ||||||||||||||
torch.tensor(gts)).item() | ||||||||||||||
is_badcase = metric_value >= self.badcase_thr | ||||||||||||||
else: | ||||||||||||||
self.metric.process([data_batch], [data_sample.to_dict()]) | ||||||||||||||
metric_value = self.metric.evaluate(1)[self.metric_key] | ||||||||||||||
is_badcase = metric_value <= self.badcase_thr | ||||||||||||||
return is_badcase, metric_value | ||||||||||||||
|
||||||||||||||
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, | ||||||||||||||
outputs: Sequence[PoseDataSample]) -> None: | ||||||||||||||
"""Run after every testing iterations. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
runner (:obj:`Runner`): The runner of the testing process. | ||||||||||||||
batch_idx (int): The index of the current batch in the test loop. | ||||||||||||||
data_batch (dict): Data from dataloader. | ||||||||||||||
outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model. | ||||||||||||||
""" | ||||||||||||||
if self.enable is False: | ||||||||||||||
return | ||||||||||||||
|
||||||||||||||
if self.out_dir is not None: | ||||||||||||||
self.out_dir = os.path.join(runner.work_dir, runner.timestamp, | ||||||||||||||
self.out_dir) | ||||||||||||||
mmengine.mkdir_or_exist(self.out_dir) | ||||||||||||||
|
||||||||||||||
self._visualizer.set_dataset_meta(runner.test_evaluator.dataset_meta) | ||||||||||||||
|
||||||||||||||
for data_sample in outputs: | ||||||||||||||
self._test_index += 1 | ||||||||||||||
|
||||||||||||||
img_path = data_sample.get('img_path') | ||||||||||||||
img_bytes = fileio.get(img_path, backend_args=self.backend_args) | ||||||||||||||
img = mmcv.imfrombytes(img_bytes, channel_order='rgb') | ||||||||||||||
data_sample = merge_data_samples([data_sample]) | ||||||||||||||
|
||||||||||||||
is_badcase, metric_value = self.check_badcase(data_batch, data_sample) | ||||||||||||||
|
||||||||||||||
if is_badcase: | ||||||||||||||
img_name, postfix = os.path.basename(img_path).rsplit( | ||||||||||||||
'.', 1) | ||||||||||||||
bboxes = data_sample.gt_instances.bboxes.astype(int).tolist() | ||||||||||||||
bbox_info = 'bbox' + str(bboxes) | ||||||||||||||
metric_postfix = self.metric_name + str(round(metric_value, 2)) | ||||||||||||||
|
||||||||||||||
self.results.append({'img': img_name, | ||||||||||||||
'bbox': bboxes, | ||||||||||||||
self.metric_name: metric_value}) | ||||||||||||||
|
||||||||||||||
badcase_name = f'{img_name}_{bbox_info}_{metric_postfix}' | ||||||||||||||
|
||||||||||||||
out_file = None | ||||||||||||||
if self.out_dir is not None: | ||||||||||||||
out_file = f'{badcase_name}.{postfix}' | ||||||||||||||
out_file = os.path.join(self.out_dir, out_file) | ||||||||||||||
|
||||||||||||||
# draw gt keypoints in blue color | ||||||||||||||
self._visualizer.kpt_color = 'blue' | ||||||||||||||
self._visualizer.link_color = 'blue' | ||||||||||||||
Review comment on lines +181 to +183: Um …… I suggest using |
||||||||||||||
img_gt_drawn = self._visualizer.add_datasample( | ||||||||||||||
badcase_name if self.show else 'test_img', | ||||||||||||||
img, | ||||||||||||||
data_sample=data_sample, | ||||||||||||||
show=False, | ||||||||||||||
draw_pred=False, | ||||||||||||||
draw_gt=True, | ||||||||||||||
draw_bbox=False, | ||||||||||||||
draw_heatmap=False, | ||||||||||||||
wait_time=self.wait_time, | ||||||||||||||
kpt_thr=self.kpt_thr, | ||||||||||||||
out_file=None, | ||||||||||||||
step=self._test_index) | ||||||||||||||
# draw pred keypoints in red color | ||||||||||||||
self._visualizer.kpt_color = 'red' | ||||||||||||||
self._visualizer.link_color = 'red' | ||||||||||||||
self._visualizer.add_datasample( | ||||||||||||||
badcase_name if self.show else 'test_img', | ||||||||||||||
img_gt_drawn, | ||||||||||||||
data_sample=data_sample, | ||||||||||||||
show=self.show, | ||||||||||||||
draw_pred=True, | ||||||||||||||
draw_gt=False, | ||||||||||||||
draw_bbox=True, | ||||||||||||||
draw_heatmap=False, | ||||||||||||||
wait_time=self.wait_time, | ||||||||||||||
kpt_thr=self.kpt_thr, | ||||||||||||||
out_file=out_file, | ||||||||||||||
step=self._test_index) | ||||||||||||||
|
||||||||||||||
def after_test_epoch(self, | ||||||||||||||
runner, | ||||||||||||||
metrics: Optional[Dict[str, float]] = None) -> None: | ||||||||||||||
"""Save the collected badcase records to ``results.json`` after the | ||||||||||||||
test epoch. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
runner (Runner): The runner of the testing process. | ||||||||||||||
metrics (Dict[str, float], optional): Evaluation results of all | ||||||||||||||
metrics on test dataset. The keys are the names of the | ||||||||||||||
metrics, and the values are corresponding results. | ||||||||||||||
""" | ||||||||||||||
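# Nothing to save if the hook is disabled or no output directory is set. | ||||||||||||||
if not self.enable or self.out_dir is None: | ||||||||||||||
return | ||||||||||||||
out_file = os.path.join(self.out_dir, 'results.json') | ||||||||||||||
with open(out_file, 'w') as f: | ||||||||||||||
json.dump(self.results, f) |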
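The thresholding rule in `check_badcase` and the file-name pattern built in `after_test_iter` can be sketched in isolation. The snippet below is a minimal, framework-free illustration rather than the hook itself: `is_badcase` and `badcase_name` are hypothetical helper names, and the metric values and image path are made-up inputs. It mirrors the two comparisons in the diff (a loss at or above `badcase_thr`, or an accuracy at or below it).

```python
import os


def is_badcase(metric_type: str, metric_value: float,
               badcase_thr: float) -> bool:
    """Apply the threshold rule for the given metric type."""
    if metric_type == 'loss':
        # A loss at or above the threshold marks a bad case.
        return metric_value >= badcase_thr
    if metric_type == 'accuracy':
        # An accuracy at or below the threshold marks a bad case.
        return metric_value <= badcase_thr
    raise KeyError(f'Unsupported metric type: {metric_type}')


def badcase_name(img_path: str, bboxes: list, metric_name: str,
                 metric_value: float) -> str:
    """Build the output name: image stem, bbox list, then metric value."""
    img_name, _ = os.path.basename(img_path).rsplit('.', 1)
    return f'{img_name}_bbox{bboxes}_{metric_name}{round(metric_value, 2)}'


# Illustrative inputs only.
print(is_badcase('loss', 7.5, 5))
print(is_badcase('accuracy', 87.5, 50))
print(badcase_name('images/000001.jpg', [[10, 20, 110, 220]], 'PCK', 43.756))
```

Note that under the accuracy rule the comparison is inclusive, so a threshold equal to the maximum possible score flags every sample.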
How about moving `badcase` into `default_runtime.py` and disabling it by default? That way, there is no need to add such a new config.
Sounds fine to me. The usage samples can be added to the documentation.
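A usage sample for the documentation might look like the fragment below. It is only a sketch based on the constructor arguments shown in this diff: the values are illustrative, the `metric` argument is left at its default, and `enable=True` is set explicitly because the constructor defaults it to `False`.

```python
# Fragment for an experiment config; values are illustrative.
default_hooks = dict(
    badcase=dict(
        type='BadCaseAnalyzeHook',
        enable=True,  # the hook does nothing unless enabled
        show=False,  # do not pop up windows during testing
        out_dir='badcase',  # joined with the runner work_dir and timestamp
        metric_type='loss',  # 'loss' or 'accuracy'
        badcase_thr=5,  # samples with loss >= 5 are treated as bad cases
    ))
```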
Is the config of `badcase` in experiment config files able to overwrite that in `default_runtime.py`?
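For context, MMEngine's `_base_` inheritance merges dict-valued fields key by key, so keys set in an experiment config should take precedence over the inherited ones while untouched keys are kept. The snippet below is a simplified pure-Python illustration of that merge behavior, not MMEngine's actual implementation; `merge_cfg` is a hypothetical helper and the config values are made up.

```python
import copy


def merge_cfg(base: dict, override: dict) -> dict:
    """Recursively merge ``override`` into a copy of ``base``."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge key by key.
            merged[key] = merge_cfg(merged[key], value)
        else:
            # Otherwise the experiment value replaces the base value.
            merged[key] = copy.deepcopy(value)
    return merged


# Stand-in for default_runtime.py, with the hook disabled by default.
base = dict(default_hooks=dict(
    checkpoint=dict(type='CheckpointHook', interval=10),
    badcase=dict(type='BadCaseAnalyzeHook', enable=False,
                 out_dir='badcase', metric_type='loss', badcase_thr=5)))

# Stand-in for an experiment config that only touches a few keys.
experiment = dict(default_hooks=dict(
    checkpoint=dict(save_best='PCK', rule='greater'),
    badcase=dict(enable=True, badcase_thr=10)))

cfg = merge_cfg(base, experiment)
print(cfg['default_hooks']['badcase'])
```

In this sketch the merged `badcase` entry keeps `type`, `out_dir` and `metric_type` from the base and takes `enable` and `badcase_thr` from the experiment config.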