-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature][MMSIG] Support badcase analyze in test (#2584)
- Loading branch information
Showing
6 changed files
with
374 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .badcase_hook import BadCaseAnalysisHook | ||
from .ema_hook import ExpMomentumEMA | ||
from .visualization_hook import PoseVisualizationHook | ||
|
||
__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA'] | ||
__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA', 'BadCaseAnalysisHook'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import json | ||
import os | ||
import warnings | ||
from typing import Dict, Optional, Sequence | ||
|
||
import mmcv | ||
import mmengine | ||
import mmengine.fileio as fileio | ||
import torch | ||
from mmengine.config import ConfigDict | ||
from mmengine.hooks import Hook | ||
from mmengine.runner import Runner | ||
from mmengine.visualization import Visualizer | ||
|
||
from mmpose.registry import HOOKS, METRICS, MODELS | ||
from mmpose.structures import PoseDataSample, merge_data_samples | ||
|
||
|
||
@HOOKS.register_module() | ||
class BadCaseAnalysisHook(Hook): | ||
"""Bad Case Analyze Hook. Used to visualize validation and testing process | ||
prediction results. | ||
In the testing phase: | ||
1. If ``show`` is True, it means that only the prediction results are | ||
visualized without storing data, so ``vis_backends`` needs to | ||
be excluded. | ||
2. If ``out_dir`` is specified, it means that the prediction results | ||
need to be saved to ``out_dir``. In order to avoid vis_backends | ||
also storing data, so ``vis_backends`` needs to be excluded. | ||
3. ``vis_backends`` takes effect if the user does not specify ``show`` | ||
and `out_dir``. You can set ``vis_backends`` to WandbVisBackend or | ||
TensorboardVisBackend to store the prediction result in Wandb or | ||
Tensorboard. | ||
Args: | ||
enable (bool): whether to draw prediction results. If it is False, | ||
it means that no drawing will be done. Defaults to False. | ||
show (bool): Whether to display the drawn image. Default to False. | ||
wait_time (float): The interval of show (s). Defaults to 0. | ||
interval (int): The interval of visualization. Defaults to 50. | ||
kpt_thr (float): The threshold to visualize the keypoints. | ||
Defaults to 0.3. | ||
out_dir (str, optional): directory where painted images | ||
will be saved in testing process. | ||
backend_args (dict, optional): Arguments to instantiate the preifx of | ||
uri corresponding backend. Defaults to None. | ||
metric_type (str): the mretic type to decide a badcase, | ||
loss or accuracy. | ||
metric (ConfigDict): The config of metric. | ||
metric_key (str): key of needed metric value in the return dict | ||
from class 'metric'. | ||
badcase_thr (float): min loss or max accuracy for a badcase. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
enable: bool = False, | ||
show: bool = False, | ||
wait_time: float = 0., | ||
interval: int = 50, | ||
kpt_thr: float = 0.3, | ||
out_dir: Optional[str] = None, | ||
backend_args: Optional[dict] = None, | ||
metric_type: str = 'loss', | ||
metric: ConfigDict = ConfigDict(type='KeypointMSELoss'), | ||
metric_key: str = 'PCK', | ||
badcase_thr: float = 5, | ||
): | ||
self._visualizer: Visualizer = Visualizer.get_current_instance() | ||
self.interval = interval | ||
self.kpt_thr = kpt_thr | ||
self.show = show | ||
if self.show: | ||
# No need to think about vis backends. | ||
self._visualizer._vis_backends = {} | ||
warnings.warn('The show is True, it means that only ' | ||
'the prediction results are visualized ' | ||
'without storing data, so vis_backends ' | ||
'needs to be excluded.') | ||
|
||
self.wait_time = wait_time | ||
self.enable = enable | ||
self.out_dir = out_dir | ||
self._test_index = 0 | ||
self.backend_args = backend_args | ||
|
||
self.metric_type = metric_type | ||
if metric_type not in ['loss', 'accuracy']: | ||
raise KeyError( | ||
f'The badcase metric type {metric_type} is not supported by ' | ||
f"{self.__class__.__name__}. Should be one of 'loss', " | ||
f"'accuracy', but got {metric_type}.") | ||
self.metric = MODELS.build(metric) if metric_type == 'loss'\ | ||
else METRICS.build(metric) | ||
self.metric_name = metric.type if metric_type == 'loss'\ | ||
else metric_key | ||
self.metric_key = metric_key | ||
self.badcase_thr = badcase_thr | ||
self.results = [] | ||
|
||
def check_badcase(self, data_batch, data_sample): | ||
"""Check whether the sample is a badcase. | ||
Args: | ||
data_batch (Sequence[dict]): A batch of data | ||
from the dataloader. | ||
data_samples (Sequence[dict]): A batch of outputs from | ||
the model. | ||
Return: | ||
is_badcase (bool): whether the sample is a badcase or not | ||
metric_value (float) | ||
""" | ||
if self.metric_type == 'loss': | ||
gts = data_sample.gt_instances.keypoints | ||
preds = data_sample.pred_instances.keypoints | ||
weights = data_sample.gt_instances.keypoints_visible | ||
with torch.no_grad(): | ||
metric_value = self.metric( | ||
torch.from_numpy(preds), torch.from_numpy(gts), | ||
torch.from_numpy(weights)).item() | ||
is_badcase = metric_value >= self.badcase_thr | ||
else: | ||
self.metric.process([data_batch], [data_sample.to_dict()]) | ||
metric_value = self.metric.evaluate(1)[self.metric_key] | ||
is_badcase = metric_value <= self.badcase_thr | ||
return is_badcase, metric_value | ||
|
||
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, | ||
outputs: Sequence[PoseDataSample]) -> None: | ||
"""Run after every testing iterations. | ||
Args: | ||
runner (:obj:`Runner`): The runner of the testing process. | ||
batch_idx (int): The index of the current batch in the test loop. | ||
data_batch (dict): Data from dataloader. | ||
outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model. | ||
""" | ||
if self.enable is False: | ||
return | ||
|
||
if self.out_dir is not None: | ||
self.out_dir = os.path.join(runner.work_dir, runner.timestamp, | ||
self.out_dir) | ||
mmengine.mkdir_or_exist(self.out_dir) | ||
|
||
self._visualizer.set_dataset_meta(runner.test_evaluator.dataset_meta) | ||
|
||
for data_sample in outputs: | ||
self._test_index += 1 | ||
|
||
img_path = data_sample.get('img_path') | ||
img_bytes = fileio.get(img_path, backend_args=self.backend_args) | ||
img = mmcv.imfrombytes(img_bytes, channel_order='rgb') | ||
data_sample = merge_data_samples([data_sample]) | ||
|
||
is_badcase, metric_value = self.check_badcase( | ||
data_batch, data_sample) | ||
|
||
if is_badcase: | ||
img_name, postfix = os.path.basename(img_path).rsplit('.', 1) | ||
bboxes = data_sample.gt_instances.bboxes.astype(int).tolist() | ||
bbox_info = 'bbox' + str(bboxes) | ||
metric_postfix = self.metric_name + str(round(metric_value, 2)) | ||
|
||
self.results.append({ | ||
'img': img_name, | ||
'bbox': bboxes, | ||
self.metric_name: metric_value | ||
}) | ||
|
||
badcase_name = f'{img_name}_{bbox_info}_{metric_postfix}' | ||
|
||
out_file = None | ||
if self.out_dir is not None: | ||
out_file = f'{badcase_name}.{postfix}' | ||
out_file = os.path.join(self.out_dir, out_file) | ||
|
||
# draw gt keypoints in blue color | ||
self._visualizer.kpt_color = 'blue' | ||
self._visualizer.link_color = 'blue' | ||
img_gt_drawn = self._visualizer.add_datasample( | ||
badcase_name if self.show else 'test_img', | ||
img, | ||
data_sample=data_sample, | ||
show=False, | ||
draw_pred=False, | ||
draw_gt=True, | ||
draw_bbox=False, | ||
draw_heatmap=False, | ||
wait_time=self.wait_time, | ||
kpt_thr=self.kpt_thr, | ||
out_file=None, | ||
step=self._test_index) | ||
# draw pred keypoints in red color | ||
self._visualizer.kpt_color = 'red' | ||
self._visualizer.link_color = 'red' | ||
self._visualizer.add_datasample( | ||
badcase_name if self.show else 'test_img', | ||
img_gt_drawn, | ||
data_sample=data_sample, | ||
show=self.show, | ||
draw_pred=True, | ||
draw_gt=False, | ||
draw_bbox=True, | ||
draw_heatmap=False, | ||
wait_time=self.wait_time, | ||
kpt_thr=self.kpt_thr, | ||
out_file=out_file, | ||
step=self._test_index) | ||
|
||
def after_test_epoch(self, | ||
runner, | ||
metrics: Optional[Dict[str, float]] = None) -> None: | ||
"""All subclasses should override this method, if they need any | ||
operations after each test epoch. | ||
Args: | ||
runner (Runner): The runner of the testing process. | ||
metrics (Dict[str, float], optional): Evaluation results of all | ||
metrics on test dataset. The keys are the names of the | ||
metrics, and the values are corresponding results. | ||
""" | ||
out_file = os.path.join(self.out_dir, 'results.json') | ||
with open(out_file, 'w') as f: | ||
json.dump(self.results, f) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os | ||
import os.path as osp | ||
import shutil | ||
import time | ||
from unittest import TestCase | ||
from unittest.mock import MagicMock | ||
|
||
import numpy as np | ||
from mmengine.config import ConfigDict | ||
from mmengine.structures import InstanceData | ||
|
||
from mmpose.engine.hooks import BadCaseAnalysisHook | ||
from mmpose.structures import PoseDataSample | ||
from mmpose.visualization import PoseLocalVisualizer | ||
|
||
|
||
def _rand_poses(num_boxes, kpt_num, h, w): | ||
center = np.random.rand(num_boxes, 2) | ||
offset = np.random.rand(num_boxes, kpt_num, 2) / 2.0 | ||
|
||
pose = center[:, None, :] + offset.clip(0, 1) | ||
pose[:, :, 0] *= w | ||
pose[:, :, 1] *= h | ||
|
||
return pose | ||
|
||
|
||
class TestBadCaseHook(TestCase): | ||
|
||
def setUp(self) -> None: | ||
kpt_num = 16 | ||
PoseLocalVisualizer.get_instance('test_badcase_hook') | ||
|
||
data_sample = PoseDataSample() | ||
data_sample.set_metainfo({ | ||
'img_path': | ||
osp.join( | ||
osp.dirname(__file__), '../../data/coco/000000000785.jpg') | ||
}) | ||
self.data_batch = {'data_samples': [data_sample] * 2} | ||
|
||
pred_det_data_sample = data_sample.clone() | ||
pred_instances = InstanceData() | ||
pred_instances.keypoints = _rand_poses(1, kpt_num, 10, 12) | ||
pred_det_data_sample.pred_instances = pred_instances | ||
|
||
gt_instances = InstanceData() | ||
gt_instances.keypoints = _rand_poses(1, kpt_num, 10, 12) | ||
gt_instances.keypoints_visible = np.ones((1, kpt_num)) | ||
gt_instances.head_size = np.random.rand(1, 1) | ||
gt_instances.bboxes = np.random.rand(1, 4) | ||
pred_det_data_sample.gt_instances = gt_instances | ||
self.outputs = [pred_det_data_sample] * 2 | ||
|
||
def test_after_test_iter(self): | ||
runner = MagicMock() | ||
runner.iter = 1 | ||
|
||
# test | ||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) | ||
out_dir = timestamp + '1' | ||
runner.work_dir = timestamp | ||
runner.timestamp = '1' | ||
hook = BadCaseAnalysisHook(enable=False, out_dir=out_dir) | ||
hook.after_test_iter(runner, 1, self.data_batch, self.outputs) | ||
self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}')) | ||
|
||
hook = BadCaseAnalysisHook( | ||
enable=True, | ||
out_dir=out_dir, | ||
metric_type='loss', | ||
metric=ConfigDict(type='KeypointMSELoss'), | ||
badcase_thr=-1, # is_badcase = True | ||
) | ||
hook.after_test_iter(runner, 1, self.data_batch, self.outputs) | ||
self.assertEqual(hook._test_index, 2) | ||
self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}')) | ||
# same image and preds/gts, so onlu one file | ||
self.assertTrue(len(os.listdir(f'{timestamp}/1/{out_dir}')) == 1) | ||
|
||
hook.after_test_epoch(runner) | ||
self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}/results.json')) | ||
shutil.rmtree(f'{timestamp}') | ||
|
||
hook = BadCaseAnalysisHook( | ||
enable=True, | ||
out_dir=out_dir, | ||
metric_type='accuracy', | ||
metric=ConfigDict(type='MpiiPCKAccuracy'), | ||
badcase_thr=-1, # is_badcase = False | ||
) | ||
hook.after_test_iter(runner, 1, self.data_batch, self.outputs) | ||
self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}')) | ||
self.assertTrue(len(os.listdir(f'{timestamp}/1/{out_dir}')) == 0) | ||
shutil.rmtree(f'{timestamp}') | ||
|
||
|
||
if __name__ == '__main__': | ||
test = TestBadCaseHook() | ||
test.setUp() | ||
test.test_after_test_iter() |
Oops, something went wrong.