[Feature] Support semantic seg metrics #332

Merged · 23 commits · Mar 17, 2021
6 changes: 5 additions & 1 deletion mmdet3d/core/evaluation/__init__.py
@@ -1,5 +1,9 @@
 from .indoor_eval import indoor_eval
 from .kitti_utils import kitti_eval, kitti_eval_coco_style
 from .lyft_eval import lyft_eval
+from .seg_eval import seg_eval

-__all__ = ['kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval']
+__all__ = [
+    'kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval',
+    'seg_eval'
+]
121 changes: 121 additions & 0 deletions mmdet3d/core/evaluation/seg_eval.py
@@ -0,0 +1,121 @@
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable


def fast_hist(preds, labels, num_classes):
    """Compute the confusion matrix for a single batch.

    Args:
        preds (np.ndarray): Prediction labels of points with shape of
            (num_points, ).
        labels (np.ndarray): Ground truth labels of points with shape of
            (num_points, ).
        num_classes (int): Number of classes.

    Returns:
        np.ndarray: Calculated confusion matrix.
    """

    k = (labels >= 0) & (labels < num_classes)
    bin_count = np.bincount(
        num_classes * labels[k].astype(int) + preds[k],
        minlength=num_classes**2)
    return bin_count[:num_classes**2].reshape(num_classes, num_classes)
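
For illustration, a minimal standalone sketch of the bincount trick `fast_hist` relies on: each (label, pred) pair is packed into a single index `label * num_classes + pred`, so one `np.bincount` over the packed indices yields the flattened confusion matrix. The values below are made up for the demo:

```python
import numpy as np

preds = np.array([0, 1, 1, 2])
labels = np.array([0, 1, 2, 2])
num_classes = 3

# Pack each (label, pred) pair into one index: label * num_classes + pred.
flat = num_classes * labels + preds  # -> [0, 4, 7, 8]
# One bincount over the packed indices gives the flattened confusion matrix.
hist = np.bincount(flat, minlength=num_classes**2).reshape(
    num_classes, num_classes)
# hist[i, j] counts points with ground truth i predicted as class j:
# [[1, 0, 0],
#  [0, 1, 0],
#  [0, 1, 1]]
```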


def per_class_iou(hist):
    """Compute the per-class IoU.

    Args:
        hist (np.ndarray): Overall confusion matrix with shape of
            (num_classes, num_classes).

    Returns:
        np.ndarray: Calculated per-class IoU.
    """

    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
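
In confusion-matrix terms this is `TP / (TP + FP + FN)` per class: the diagonal entry divided by (row sum + column sum - diagonal). A quick check with the illustrative matrix from the sketch above:

```python
import numpy as np

hist = np.array([[1, 0, 0],
                 [0, 1, 0],
                 [0, 1, 1]])
iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
# class 0: 1 / (1 + 1 - 1) = 1.0
# class 1: 1 / (1 + 2 - 1) = 0.5
# class 2: 1 / (2 + 1 - 1) = 0.5
```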


def get_acc(hist):
    """Compute the overall accuracy.

    Args:
        hist (np.ndarray): Overall confusion matrix with shape of
            (num_classes, num_classes).

    Returns:
        float: Calculated overall accuracy.
    """

    return np.diag(hist).sum() / hist.sum()


def get_acc_cls(hist):
    """Compute the class-averaged accuracy.

    Args:
        hist (np.ndarray): Overall confusion matrix with shape of
            (num_classes, num_classes).

    Returns:
        float: Calculated class-averaged accuracy.
    """

    return np.nanmean(np.diag(hist) / hist.sum(axis=1))
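
The two accuracy metrics differ in weighting: `get_acc` is overall point accuracy (the trace of the matrix over the total count), while `get_acc_cls` averages per-class recall, so rare classes count as much as frequent ones. Continuing the same illustrative matrix:

```python
import numpy as np

hist = np.array([[1, 0, 0],
                 [0, 1, 0],
                 [0, 1, 1]])
acc = np.diag(hist).sum() / hist.sum()  # 3 / 4 = 0.75
acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))  # mean(1, 1, 0.5) ~= 0.83
```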


def seg_eval(gt_labels, seg_preds, label2cat, logger=None):
    """Semantic segmentation evaluation.

    Evaluate the results of semantic segmentation.

    Args:
        gt_labels (list[torch.Tensor]): Ground truth labels.
        seg_preds (list[torch.Tensor]): Predictions.
        label2cat (dict): Map from label index to category name.
        logger (logging.Logger | str | None): The way to print the metric
            summary. See `mmcv.utils.print_log()` for details. Default: None.

    Returns:
        dict[str, float]: Dict of results.
    """
    assert len(seg_preds) == len(gt_labels)

    hist_list = []
    for i in range(len(seg_preds)):
        hist_list.append(
            fast_hist(seg_preds[i].numpy().astype(int),
                      gt_labels[i].numpy().astype(int), len(label2cat)))
    total_hist = sum(hist_list)
    iou = per_class_iou(total_hist)
    miou = np.nanmean(iou)
    acc = get_acc(total_hist)
    acc_cls = get_acc_cls(total_hist)

    header = ['classes']
    for i in range(len(label2cat)):
        header.append(label2cat[i])
    header.extend(['miou', 'acc', 'acc_cls'])

    ret_dict = dict()
    table_columns = [['results']]
    for i in range(len(label2cat)):
        ret_dict[label2cat[i]] = float(iou[i])
        table_columns.append([f'{iou[i]:.4f}'])
    ret_dict['miou'] = float(miou)
    ret_dict['acc'] = float(acc)
    ret_dict['acc_cls'] = float(acc_cls)

    table_columns.append([f'{miou:.4f}'])
    table_columns.append([f'{acc:.4f}'])
    table_columns.append([f'{acc_cls:.4f}'])

    table_data = [header]
    table_rows = list(zip(*table_columns))
    table_data += table_rows
    table = AsciiTable(table_data)
    table.inner_footing_row_border = True
    print_log('\n' + table.table, logger=logger)

    return ret_dict
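
A minimal usage sketch of the new metric (the tensors and category names below are illustrative; `seg_eval` is re-exported from `mmdet3d.core.evaluation` by the `__init__.py` change above):

```python
import torch
from mmdet3d.core.evaluation import seg_eval

gt_labels = [torch.tensor([0, 0, 1, 1])]
seg_preds = [torch.tensor([0, 1, 1, 1])]
label2cat = {0: 'car', 1: 'truck'}

# Prints an AsciiTable summary and returns per-class IoU plus
# 'miou', 'acc' and 'acc_cls'.
results = seg_eval(gt_labels, seg_preds, label2cat)
```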
35 changes: 35 additions & 0 deletions tests/test_metrics/test_seg_eval.py
@@ -0,0 +1,35 @@
import numpy as np
import pytest
import torch

from mmdet3d.core.evaluation.seg_eval import seg_eval


def test_seg_eval():
    if not torch.cuda.is_available():
        pytest.skip()
    seg_preds = [
        torch.Tensor(
            [0, 0, 1, 0, 2, 1, 3, 1, 1, 0, 2, 2, 2, 2, 1, 3, 0, 3, 3, 3])
    ]
    gt_labels = [
        torch.Tensor(
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
    ]

    label2cat = {
        0: 'car',
        1: 'bicycle',
        2: 'motorcycle',
        3: 'truck',
    }
    ret_value = seg_eval(gt_labels, seg_preds, label2cat)

    assert np.isclose(ret_value['car'], 0.428571429)
    assert np.isclose(ret_value['bicycle'], 0.428571429)
    assert np.isclose(ret_value['motorcycle'], 0.6666667)
    assert np.isclose(ret_value['truck'], 0.6666667)

    assert np.isclose(ret_value['acc'], 0.7)
    assert np.isclose(ret_value['acc_cls'], 0.7)
    assert np.isclose(ret_value['miou'], 0.547619048)