Skip to content

Commit

Permalink
merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
yezhen17 committed Mar 20, 2021
2 parents 31133e1 + f0ba0ce commit 03db9d4
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Support methods
- [x] [3DSSD (CVPR'2020)](configs/3dssd/README.md)
- [x] [Part-A2 (TPAMI'2020)](configs/parta2/README.md)
- [x] [MVXNet (ICRA'2019)](configs/mvxnet/README.md)
- [x] [CenterPoint (Arxiv'2020)](configs/centerpoint/README.md)
- [x] [CenterPoint (CVPR'2021)](configs/centerpoint/README.md)
- [x] [SSN (ECCV'2020)](configs/ssn/README.md)

| | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net |
Expand Down
2 changes: 1 addition & 1 deletion README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代
- [x] [3DSSD (CVPR'2020)](configs/3dssd/README.md)
- [x] [Part-A2 (TPAMI'2020)](configs/parta2/README.md)
- [x] [MVXNet (ICRA'2019)](configs/mvxnet/README.md)
- [x] [CenterPoint (Arxiv'2020)](configs/centerpoint/README.md)
- [x] [CenterPoint (CVPR'2021)](configs/centerpoint/README.md)
- [x] [SSN (ECCV'2020)](configs/ssn/README.md)

| | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net |
Expand Down
8 changes: 4 additions & 4 deletions configs/centerpoint/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ We follow the below style to name config files. Contributors are advised to foll

`{dataset}`: dataset like nus-3d, kitti-3d, lyft-3d, scannet-3d, sunrgbd-3d. We also indicate the number of classes we are using if there exist multiple settings, e.g., kitti-3d-3class and kitti-3d-car means training on KITTI dataset with 3 classes and single class, respectively.
```
@article{yin2020center,
title={Center-based 3d object detection and tracking},
@article{yin2021center,
title={Center-based 3D Object Detection and Tracking},
author={Yin, Tianwei and Zhou, Xingyi and Kr{\"a}henb{\"u}hl, Philipp},
journal={arXiv preprint arXiv:2006.11275},
year={2020}
journal={CVPR},
year={2021},
}
```

Expand Down
6 changes: 5 additions & 1 deletion mmdet3d/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .indoor_eval import indoor_eval
from .kitti_utils import kitti_eval, kitti_eval_coco_style
from .lyft_eval import lyft_eval
from .seg_eval import seg_eval

__all__ = ['kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval']
__all__ = [
'kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval',
'seg_eval'
]
121 changes: 121 additions & 0 deletions mmdet3d/core/evaluation/seg_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable


def fast_hist(preds, labels, num_classes):
"""Compute the confusion matrix for every batch.
Args:
preds (np.ndarray): Prediction labels of points with shape of
(num_points, ).
labels (np.ndarray): Ground truth labels of points with shape of
(num_points, ).
num_classes (int): number of classes
Returns:
np.ndarray: Calculated confusion matrix.
"""

k = (labels >= 0) & (labels < num_classes)
bin_count = np.bincount(
num_classes * labels[k].astype(int) + preds[k],
minlength=num_classes**2)
return bin_count[:num_classes**2].reshape(num_classes, num_classes)


def per_class_iou(hist):
"""Compute the per class iou.
Args:
hist(np.ndarray): Overall confusion martix
(num_classes, num_classes ).
Returns:
np.ndarray: Calculated per class iou
"""

return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))


def get_acc(hist):
"""Compute the overall accuracy.
Args:
hist(np.ndarray): Overall confusion martix
(num_classes, num_classes ).
Returns:
float: Calculated overall acc
"""

return np.diag(hist).sum() / hist.sum()


def get_acc_cls(hist):
"""Compute the class average accuracy.
Args:
hist(np.ndarray): Overall confusion martix
(num_classes, num_classes ).
Returns:
float: Calculated class average acc
"""

return np.nanmean(np.diag(hist) / hist.sum(axis=1))


def seg_eval(gt_labels, seg_preds, label2cat, logger=None):
"""Semantic Segmentation Evaluation.
Evaluate the result of the Semantic Segmentation.
Args:
gt_labels (list[torch.Tensor]): Ground truth labels.
seg_preds (list[torch.Tensor]): Predtictions
label2cat (dict): Map from label to category.
logger (logging.Logger | str | None): The way to print the mAP
summary. See `mmdet.utils.print_log()` for details. Default: None.
Return:
dict[str, float]: Dict of results.
"""
assert len(seg_preds) == len(gt_labels)

hist_list = []
for i in range(len(seg_preds)):
hist_list.append(
fast_hist(seg_preds[i].numpy().astype(int),
gt_labels[i].numpy().astype(int), len(label2cat)))
iou = per_class_iou(sum(hist_list))
miou = np.nanmean(iou)
acc = get_acc(sum(hist_list))
acc_cls = get_acc_cls(sum(hist_list))

header = ['classes']
for i in range(len(label2cat)):
header.append(label2cat[i])
header.extend(['miou', 'acc', 'acc_cls'])

ret_dict = dict()
table_columns = [['results']]
for i in range(len(label2cat)):
ret_dict[label2cat[i]] = float(iou[i])
table_columns.append([f'{iou[i]:.4f}'])
ret_dict['miou'] = float(miou)
ret_dict['acc'] = float(acc)
ret_dict['acc_cls'] = float(acc_cls)

table_columns.append([f'{miou:.4f}'])
table_columns.append([f'{acc:.4f}'])
table_columns.append([f'{acc_cls:.4f}'])

table_data = [header]
table_rows = list(zip(*table_columns))
table_data += table_rows
table = AsciiTable(table_data)
table.inner_footing_row_border = True
print_log('\n' + table.table, logger=logger)

return ret_dict
35 changes: 35 additions & 0 deletions tests/test_metrics/test_seg_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import pytest
import torch

from mmdet3d.core.evaluation.seg_eval import seg_eval


def test_indoor_eval():
if not torch.cuda.is_available():
pytest.skip()
seg_preds = [
torch.Tensor(
[0, 0, 1, 0, 2, 1, 3, 1, 1, 0, 2, 2, 2, 2, 1, 3, 0, 3, 3, 3])
]
gt_labels = [
torch.Tensor(
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
]

label2cat = {
0: 'car',
1: 'bicycle',
2: 'motorcycle',
3: 'truck',
}
ret_value = seg_eval(gt_labels, seg_preds, label2cat)

assert np.isclose(ret_value['car'], 0.428571429)
assert np.isclose(ret_value['bicycle'], 0.428571429)
assert np.isclose(ret_value['motorcycle'], 0.6666667)
assert np.isclose(ret_value['truck'], 0.6666667)

assert np.isclose(ret_value['acc'], 0.7)
assert np.isclose(ret_value['acc_cls'], 0.7)
assert np.isclose(ret_value['miou'], 0.547619048)

0 comments on commit 03db9d4

Please sign in to comment.