diff --git a/README.md b/README.md index f93acb0590..9ed761295e 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ Support methods - [x] [3DSSD (CVPR'2020)](configs/3dssd/README.md) - [x] [Part-A2 (TPAMI'2020)](configs/parta2/README.md) - [x] [MVXNet (ICRA'2019)](configs/mvxnet/README.md) -- [x] [CenterPoint (Arxiv'2020)](configs/centerpoint/README.md) +- [x] [CenterPoint (CVPR'2021)](configs/centerpoint/README.md) - [x] [SSN (ECCV'2020)](configs/ssn/README.md) | | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net | diff --git a/README_zh-CN.md b/README_zh-CN.md index b0c4c4936e..641a9c8baf 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -83,7 +83,7 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代 - [x] [3DSSD (CVPR'2020)](configs/3dssd/README.md) - [x] [Part-A2 (TPAMI'2020)](configs/parta2/README.md) - [x] [MVXNet (ICRA'2019)](configs/mvxnet/README.md) -- [x] [CenterPoint (Arxiv'2020)](configs/centerpoint/README.md) +- [x] [CenterPoint (CVPR'2021)](configs/centerpoint/README.md) - [x] [SSN (ECCV'2020)](configs/ssn/README.md) | | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net | diff --git a/configs/centerpoint/README.md b/configs/centerpoint/README.md index 03372ce610..ffab49d7ac 100644 --- a/configs/centerpoint/README.md +++ b/configs/centerpoint/README.md @@ -27,11 +27,11 @@ We follow the below style to name config files. Contributors are advised to foll `{dataset}`: dataset like nus-3d, kitti-3d, lyft-3d, scannet-3d, sunrgbd-3d. We also indicate the number of classes we are using if there exist multiple settings, e.g., kitti-3d-3class and kitti-3d-car means training on KITTI dataset with 3 classes and single class, respectively. ``` -@article{yin2020center, - title={Center-based 3d object detection and tracking}, +@article{yin2021center, + title={Center-based 3D Object Detection and Tracking}, author={Yin, Tianwei and Zhou, Xingyi and Kr{\"a}henb{\"u}hl, Philipp}, - journal={arXiv preprint arXiv:2006.11275}, - year={2020} + journal={CVPR}, + year={2021}, } ``` diff --git a/mmdet3d/core/evaluation/__init__.py b/mmdet3d/core/evaluation/__init__.py index 0d472ca01b..f8cd210a33 100644 --- a/mmdet3d/core/evaluation/__init__.py +++ b/mmdet3d/core/evaluation/__init__.py @@ -1,5 +1,9 @@ from .indoor_eval import indoor_eval from .kitti_utils import kitti_eval, kitti_eval_coco_style from .lyft_eval import lyft_eval +from .seg_eval import seg_eval -__all__ = ['kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval'] +__all__ = [ + 'kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval', + 'seg_eval' +] diff --git a/mmdet3d/core/evaluation/seg_eval.py b/mmdet3d/core/evaluation/seg_eval.py new file mode 100644 index 0000000000..ad60e8e350 --- /dev/null +++ b/mmdet3d/core/evaluation/seg_eval.py @@ -0,0 +1,121 @@ +import numpy as np +from mmcv.utils import print_log +from terminaltables import AsciiTable + + +def fast_hist(preds, labels, num_classes): + """Compute the confusion matrix for every batch. + + Args: + preds (np.ndarray): Prediction labels of points with shape of + (num_points, ). + labels (np.ndarray): Ground truth labels of points with shape of + (num_points, ). + num_classes (int): number of classes + + Returns: + np.ndarray: Calculated confusion matrix. + """ + + k = (labels >= 0) & (labels < num_classes) + bin_count = np.bincount( + num_classes * labels[k].astype(int) + preds[k], + minlength=num_classes**2) + return bin_count[:num_classes**2].reshape(num_classes, num_classes) + + +def per_class_iou(hist): + """Compute the per class iou. + + Args: + hist(np.ndarray): Overall confusion martix + (num_classes, num_classes ). + + Returns: + np.ndarray: Calculated per class iou + """ + + return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) + + +def get_acc(hist): + """Compute the overall accuracy. + + Args: + hist(np.ndarray): Overall confusion martix + (num_classes, num_classes ). + + Returns: + float: Calculated overall acc + """ + + return np.diag(hist).sum() / hist.sum() + + +def get_acc_cls(hist): + """Compute the class average accuracy. + + Args: + hist(np.ndarray): Overall confusion martix + (num_classes, num_classes ). + + Returns: + float: Calculated class average acc + """ + + return np.nanmean(np.diag(hist) / hist.sum(axis=1)) + + +def seg_eval(gt_labels, seg_preds, label2cat, logger=None): + """Semantic Segmentation Evaluation. + + Evaluate the result of the Semantic Segmentation. + + Args: + gt_labels (list[torch.Tensor]): Ground truth labels. + seg_preds (list[torch.Tensor]): Predtictions + label2cat (dict): Map from label to category. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `mmdet.utils.print_log()` for details. Default: None. + + Return: + dict[str, float]: Dict of results. + """ + assert len(seg_preds) == len(gt_labels) + + hist_list = [] + for i in range(len(seg_preds)): + hist_list.append( + fast_hist(seg_preds[i].numpy().astype(int), + gt_labels[i].numpy().astype(int), len(label2cat))) + iou = per_class_iou(sum(hist_list)) + miou = np.nanmean(iou) + acc = get_acc(sum(hist_list)) + acc_cls = get_acc_cls(sum(hist_list)) + + header = ['classes'] + for i in range(len(label2cat)): + header.append(label2cat[i]) + header.extend(['miou', 'acc', 'acc_cls']) + + ret_dict = dict() + table_columns = [['results']] + for i in range(len(label2cat)): + ret_dict[label2cat[i]] = float(iou[i]) + table_columns.append([f'{iou[i]:.4f}']) + ret_dict['miou'] = float(miou) + ret_dict['acc'] = float(acc) + ret_dict['acc_cls'] = float(acc_cls) + + table_columns.append([f'{miou:.4f}']) + table_columns.append([f'{acc:.4f}']) + table_columns.append([f'{acc_cls:.4f}']) + + table_data = [header] + table_rows = list(zip(*table_columns)) + table_data += table_rows + table = AsciiTable(table_data) + table.inner_footing_row_border = True + print_log('\n' + table.table, logger=logger) + + return ret_dict diff --git a/tests/test_metrics/test_seg_eval.py b/tests/test_metrics/test_seg_eval.py new file mode 100644 index 0000000000..d8850775ad --- /dev/null +++ b/tests/test_metrics/test_seg_eval.py @@ -0,0 +1,35 @@ +import numpy as np +import pytest +import torch + +from mmdet3d.core.evaluation.seg_eval import seg_eval + + +def test_indoor_eval(): + if not torch.cuda.is_available(): + pytest.skip() + seg_preds = [ + torch.Tensor( + [0, 0, 1, 0, 2, 1, 3, 1, 1, 0, 2, 2, 2, 2, 1, 3, 0, 3, 3, 3]) + ] + gt_labels = [ + torch.Tensor( + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) + ] + + label2cat = { + 0: 'car', + 1: 'bicycle', + 2: 'motorcycle', + 3: 'truck', + } + ret_value = seg_eval(gt_labels, seg_preds, label2cat) + + assert np.isclose(ret_value['car'], 0.428571429) + assert np.isclose(ret_value['bicycle'], 0.428571429) + assert np.isclose(ret_value['motorcycle'], 0.6666667) + assert np.isclose(ret_value['truck'], 0.6666667) + + assert np.isclose(ret_value['acc'], 0.7) + assert np.isclose(ret_value['acc_cls'], 0.7) + assert np.isclose(ret_value['miou'], 0.547619048)