[Refactor] Use mmeval.MeanIoU for SegMetric (#1929)
* Use mmeval.MeanIoU

* fix comments and add self.reset
ZCMax authored Oct 19, 2022
1 parent b8775d8 commit c4bc2f3
Showing 2 changed files with 96 additions and 127 deletions.
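
For orientation, the sketch below shows how the refactored metric might be wired into an mmengine-style evaluator config after this change. It is not part of the commit: the keyword arguments (`ignore_index`, `dist_backend`) are simply forwarded to `mmeval.MeanIoU` by the new `__init__`, and the values shown are illustrative.

```python
# Hypothetical evaluator config (illustrative values, not from this commit).
# Extra keyword arguments are now forwarded to mmeval.MeanIoU.
val_evaluator = dict(
    type='SegMetric',
    ignore_index=4,            # dataset-specific "unlabeled" class index
    dist_backend='torch_cpu')  # distributed backend handled by mmeval
test_evaluator = val_evaluator
```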
199 changes: 88 additions & 111 deletions mmdet3d/evaluation/metrics/seg_metric.py
@@ -1,47 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from typing import Dict, Optional, Sequence
import warnings
from typing import Sequence

import mmcv
import numpy as np
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from mmengine.logging import print_log
from mmeval.metrics import MeanIoU
from terminaltables import AsciiTable

from mmdet3d.evaluation import seg_eval
from mmdet3d.registry import METRICS


@METRICS.register_module()
class SegMetric(BaseMetric):
"""3D semantic segmentation evaluation metric.
class SegMetric(MeanIoU):
"""A wrapper of ``mmeval.MeanIoU`` for 3D semantic segmentation.
This wrapper implements the `process` method that parses predictions and
labels from inputs. This enables ``mmengine.Evaluator`` to handle the data
flow of different tasks through a unified interface.
In addition, this wrapper implements the ``evaluate`` method that
parses metric results and prints a pretty table of per-class metrics.
Args:
collect_device (str, optional): Device name used for collecting
results from different ranks during distributed training.
Must be 'cpu' or 'gpu'. Defaults to 'cpu'.
prefix (str): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Default: None.
pklfile_prefix (str, optional): The prefix of pkl files, including
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
submission_prefix (str, optional): The prefix of submission data.
If not specified, the submission data will not be generated.
Default: None.
dist_backend (str | None): The name of the distributed communication
backend. Refer to :class:`mmeval.BaseMetric`.
Defaults to 'torch_cpu'.
**kwargs: Keyword parameters passed to :class:`mmeval.MeanIoU`.
"""

def __init__(self,
collect_device: str = 'cpu',
prefix: Optional[str] = None,
pklfile_prefix: str = None,
submission_prefix: str = None,
**kwargs):
self.pklfile_prefix = pklfile_prefix
self.submission_prefix = submission_prefix
super(SegMetric, self).__init__(
prefix=prefix, collect_device=collect_device)
def __init__(self, dist_backend='torch_cpu', **kwargs):
iou_metrics = kwargs.pop('iou_metrics', None)
if iou_metrics is not None:
warnings.warn(
'DeprecationWarning: The `iou_metrics` parameter of '
'`SegMetric` is deprecated; all metrics are returned by default now.')
collect_device = kwargs.pop('collect_device', None)

if collect_device is not None:
warnings.warn(
'DeprecationWarning: The `collect_device` parameter of '
'`SegMetric` is deprecated, use `dist_backend` instead.')

# Changes the default value of `classwise_results` to True.
super().__init__(
classwise_results=True, dist_backend=dist_backend, **kwargs)

def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions.
@@ -55,83 +55,60 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
data_samples (Sequence[dict]): A batch of outputs from
the model.
"""
predictions, labels = [], []
for data_sample in data_samples:
pred_3d = data_sample['pred_pts_seg']
eval_ann_info = data_sample['eval_ann_info']
cpu_pred_3d = dict()
for k, v in pred_3d.items():
if hasattr(v, 'to'):
cpu_pred_3d[k] = v.to('cpu').numpy()
else:
cpu_pred_3d[k] = v
self.results.append((eval_ann_info, cpu_pred_3d))

def format_results(self, results):
r"""Format the results to txt file. Refer to `ScanNet documentation
<http://kaldir.vc.in.tum.de/scannet_benchmark/documentation>`_.
Args:
outputs (list[dict]): Testing results of the dataset.
Returns:
tuple: (outputs, tmp_dir), outputs is the detection results,
tmp_dir is the temporal directory created for saving submission
files when ``submission_prefix`` is not specified.
# (num_points, ) -> (num_points, 1)
pred = data_sample['pred_pts_seg']['pts_semantic_mask'].unsqueeze(
-1)
label = data_sample['gt_pts_seg']['pts_semantic_mask'].unsqueeze(
-1)
predictions.append(pred)
labels.append(label)
self.add(predictions, labels)

def evaluate(self, *args, **kwargs):
"""Returns metric results and print pretty table of metrics per class.
This method would be invoked by ``mmengine.Evaluator``.
"""

submission_prefix = self.submission_prefix
if submission_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
submission_prefix = osp.join(tmp_dir.name, 'results')
mmcv.mkdir_or_exist(submission_prefix)
ignore_index = self.dataset_meta['ignore_index']
# need to map network output to original label idx
cat2label = np.zeros(len(self.dataset_meta['label2cat'])).astype(
np.int)
for original_label, output_idx in self.dataset_meta['label2cat'].items(
):
if output_idx != ignore_index:
cat2label[output_idx] = original_label

for i, (eval_ann, result) in enumerate(results):
sample_idx = eval_ann['point_cloud']['lidar_idx']
pred_sem_mask = result['semantic_mask'].numpy().astype(np.int)
pred_label = cat2label[pred_sem_mask]
curr_file = f'{submission_prefix}/{sample_idx}.txt'
np.savetxt(curr_file, pred_label, fmt='%d')

def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
Dict[str, float]: The computed metrics. The keys are the names of
the metrics, and the values are corresponding results.
"""
logger: MMLogger = MMLogger.get_current_instance()

if self.submission_prefix:
self.format_results(results)
return None

label2cat = self.dataset_meta['label2cat']
ignore_index = self.dataset_meta['ignore_index']

gt_semantic_masks = []
pred_semantic_masks = []

for eval_ann, sinlge_pred_results in results:
gt_semantic_masks.append(eval_ann['pts_semantic_mask'])
pred_semantic_masks.append(
sinlge_pred_results['pts_semantic_mask'])

ret_dict = seg_eval(
gt_semantic_masks,
pred_semantic_masks,
label2cat,
ignore_index,
logger=logger)

return ret_dict
metric_results = self.compute(*args, **kwargs)
self.reset()

classwise_results = metric_results['classwise_results']
del metric_results['classwise_results']

# Ascii table of the metric results per class.
header = ['Class']
header += classwise_results.keys()
classes = self.dataset_meta['classes']
table_data = [header]
for i in range(self.num_classes):
row_data = [classes[i]]
for _, value in classwise_results.items():
row_data.append(f'{value[i]*100:.2f}')
table_data.append(row_data)

table = AsciiTable(table_data)
print_log('per class results:', logger='current')
print_log('\n' + table.table, logger='current')

# Ascii table of the metric results overall.
header = ['Class']
header += metric_results.keys()

table_data = [header]
row_data = ['results']
for _, value in metric_results.items():
row_data.append(f'{value*100:.2f}')
table_data.append(row_data)
table = AsciiTable(table_data)
table.inner_footing_row_border = True
print_log('overall results:', logger='current')
print_log('\n' + table.table, logger='current')

# Multiply values by 100 to convert them to percentages, then round.
evaluate_results = {
k: round(v * 100, 2)
for k, v in metric_results.items()
}
return evaluate_results
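
To make the new data flow concrete, here is a minimal standalone sketch of driving the wrapper directly, modeled on the updated unit test in the next file. Import paths, class names, and tensor values are assumptions for illustration: `process` reads `pts_semantic_mask` from `pred_pts_seg` and `gt_pts_seg`, and `evaluate` prints the per-class and overall tables before returning percentage-scaled scores.

```python
import torch

from mmdet3d.evaluation.metrics import SegMetric  # assumed import path
from mmdet3d.structures import Det3DDataSample, PointData

# Fake semantic segmentation result for one point cloud.
# Classes 0-3 are real classes; 4 is the "unlabeled" / ignore index.
pred = torch.LongTensor([0, 0, 1, 2, 3, 3])
gt = torch.LongTensor([0, 1, 1, 2, 3, 4])

data_sample = Det3DDataSample()
data_sample.pred_pts_seg = PointData(pts_semantic_mask=pred)
data_sample.gt_pts_seg = PointData(pts_semantic_mask=gt)

metric = SegMetric(ignore_index=4)
metric.dataset_meta = dict(classes=('car', 'bicycle', 'motorcycle', 'truck'))

# Mirror the unit test: samples are converted to plain dicts before `process`.
metric.process(data_batch={}, data_samples=[data_sample.to_dict()])
res = metric.evaluate(1)  # prints the per-class and overall AsciiTables
print(res)                # overall scores scaled to percentages
```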
24 changes: 8 additions & 16 deletions tests/test_evaluation/test_metrics/test_seg_metric.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import numpy as np
import torch
from mmengine.structures import BaseDataElement

@@ -13,19 +12,18 @@ class TestSegMetric(unittest.TestCase):

def _demo_mm_model_output(self):
"""Create a superset of inputs needed to run test or train batches."""
pred_pts_semantic_mask = torch.Tensor([
pred_pts_semantic_mask = torch.LongTensor([
0, 0, 1, 0, 0, 2, 1, 3, 1, 2, 1, 0, 2, 2, 2, 2, 1, 3, 0, 3, 3, 3, 3
])
pred_pts_seg_data = dict(pts_semantic_mask=pred_pts_semantic_mask)
data_sample = Det3DDataSample()
data_sample.pred_pts_seg = PointData(**pred_pts_seg_data)

gt_pts_semantic_mask = np.array([
0, 0, 0, 255, 0, 0, 1, 1, 1, 255, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3,
3, 255
])
ann_info_data = dict(pts_semantic_mask=gt_pts_semantic_mask)
data_sample.eval_ann_info = ann_info_data
gt_pts_semantic_mask = torch.LongTensor(([
0, 0, 0, 4, 0, 0, 1, 1, 1, 4, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4
]))
gt_pts_seg_data = dict(pts_semantic_mask=gt_pts_semantic_mask)
data_sample.gt_pts_seg = PointData(**gt_pts_seg_data)

batch_data_samples = [data_sample]

@@ -40,14 +38,8 @@ def _demo_mm_model_output(self):
def test_evaluate(self):
data_batch = {}
predictions = self._demo_mm_model_output()
label2cat = {
0: 'car',
1: 'bicycle',
2: 'motorcycle',
3: 'truck',
}
dataset_meta = dict(label2cat=label2cat, ignore_index=255)
seg_metric = SegMetric()
dataset_meta = dict(classes=('car', 'bicycle', 'motorcycle', 'truck'))
seg_metric = SegMetric(ignore_index=len(dataset_meta['classes']))
seg_metric.dataset_meta = dataset_meta
seg_metric.process(data_batch, predictions)
res = seg_metric.evaluate(1)
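
A hedged note on the result of the call above: judging from the new `evaluate` implementation, `res` should contain only the overall scores, already scaled to percentages, while the per-class numbers are only printed and then dropped. The key names below come from `mmeval.MeanIoU` and are assumptions that may differ across mmeval versions.

```python
# Hypothetical checks on `res` (key names assumed from mmeval.MeanIoU).
assert isinstance(res, dict)
assert 0.0 <= res['mIoU'] <= 100.0     # scores are percentages
assert 'classwise_results' not in res  # per-class values are printed only
```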