-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
167 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,9 @@ | ||
from .indoor_eval import indoor_eval | ||
from .kitti_utils import kitti_eval, kitti_eval_coco_style | ||
from .lyft_eval import lyft_eval | ||
from .seg_eval import seg_eval | ||
|
||
__all__ = ['kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval'] | ||
__all__ = [ | ||
'kitti_eval_coco_style', 'kitti_eval', 'indoor_eval', 'lyft_eval', | ||
'seg_eval' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import numpy as np | ||
from mmcv.utils import print_log | ||
from terminaltables import AsciiTable | ||
|
||
|
||
def fast_hist(preds, labels, num_classes): | ||
"""Compute the confusion matrix for every batch. | ||
Args: | ||
preds (np.ndarray): Prediction labels of points with shape of | ||
(num_points, ). | ||
labels (np.ndarray): Ground truth labels of points with shape of | ||
(num_points, ). | ||
num_classes (int): number of classes | ||
Returns: | ||
np.ndarray: Calculated confusion matrix. | ||
""" | ||
|
||
k = (labels >= 0) & (labels < num_classes) | ||
bin_count = np.bincount( | ||
num_classes * labels[k].astype(int) + preds[k], | ||
minlength=num_classes**2) | ||
return bin_count[:num_classes**2].reshape(num_classes, num_classes) | ||
|
||
|
||
def per_class_iou(hist): | ||
"""Compute the per class iou. | ||
Args: | ||
hist(np.ndarray): Overall confusion martix | ||
(num_classes, num_classes ). | ||
Returns: | ||
np.ndarray: Calculated per class iou | ||
""" | ||
|
||
return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) | ||
|
||
|
||
def get_acc(hist): | ||
"""Compute the overall accuracy. | ||
Args: | ||
hist(np.ndarray): Overall confusion martix | ||
(num_classes, num_classes ). | ||
Returns: | ||
float: Calculated overall acc | ||
""" | ||
|
||
return np.diag(hist).sum() / hist.sum() | ||
|
||
|
||
def get_acc_cls(hist): | ||
"""Compute the class average accuracy. | ||
Args: | ||
hist(np.ndarray): Overall confusion martix | ||
(num_classes, num_classes ). | ||
Returns: | ||
float: Calculated class average acc | ||
""" | ||
|
||
return np.nanmean(np.diag(hist) / hist.sum(axis=1)) | ||
|
||
|
||
def seg_eval(gt_labels, seg_preds, label2cat, logger=None): | ||
"""Semantic Segmentation Evaluation. | ||
Evaluate the result of the Semantic Segmentation. | ||
Args: | ||
gt_labels (list[torch.Tensor]): Ground truth labels. | ||
seg_preds (list[torch.Tensor]): Predtictions | ||
label2cat (dict): Map from label to category. | ||
logger (logging.Logger | str | None): The way to print the mAP | ||
summary. See `mmdet.utils.print_log()` for details. Default: None. | ||
Return: | ||
dict[str, float]: Dict of results. | ||
""" | ||
assert len(seg_preds) == len(gt_labels) | ||
|
||
hist_list = [] | ||
for i in range(len(seg_preds)): | ||
hist_list.append( | ||
fast_hist(seg_preds[i].numpy().astype(int), | ||
gt_labels[i].numpy().astype(int), len(label2cat))) | ||
iou = per_class_iou(sum(hist_list)) | ||
miou = np.nanmean(iou) | ||
acc = get_acc(sum(hist_list)) | ||
acc_cls = get_acc_cls(sum(hist_list)) | ||
|
||
header = ['classes'] | ||
for i in range(len(label2cat)): | ||
header.append(label2cat[i]) | ||
header.extend(['miou', 'acc', 'acc_cls']) | ||
|
||
ret_dict = dict() | ||
table_columns = [['results']] | ||
for i in range(len(label2cat)): | ||
ret_dict[label2cat[i]] = float(iou[i]) | ||
table_columns.append([f'{iou[i]:.4f}']) | ||
ret_dict['miou'] = float(miou) | ||
ret_dict['acc'] = float(acc) | ||
ret_dict['acc_cls'] = float(acc_cls) | ||
|
||
table_columns.append([f'{miou:.4f}']) | ||
table_columns.append([f'{acc:.4f}']) | ||
table_columns.append([f'{acc_cls:.4f}']) | ||
|
||
table_data = [header] | ||
table_rows = list(zip(*table_columns)) | ||
table_data += table_rows | ||
table = AsciiTable(table_data) | ||
table.inner_footing_row_border = True | ||
print_log('\n' + table.table, logger=logger) | ||
|
||
return ret_dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
|
||
from mmdet3d.core.evaluation.seg_eval import seg_eval | ||
|
||
|
||
def test_indoor_eval(): | ||
if not torch.cuda.is_available(): | ||
pytest.skip() | ||
seg_preds = [ | ||
torch.Tensor( | ||
[0, 0, 1, 0, 2, 1, 3, 1, 1, 0, 2, 2, 2, 2, 1, 3, 0, 3, 3, 3]) | ||
] | ||
gt_labels = [ | ||
torch.Tensor( | ||
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) | ||
] | ||
|
||
label2cat = { | ||
0: 'car', | ||
1: 'bicycle', | ||
2: 'motorcycle', | ||
3: 'truck', | ||
} | ||
ret_value = seg_eval(gt_labels, seg_preds, label2cat) | ||
|
||
assert np.isclose(ret_value['car'], 0.428571429) | ||
assert np.isclose(ret_value['bicycle'], 0.428571429) | ||
assert np.isclose(ret_value['motorcycle'], 0.6666667) | ||
assert np.isclose(ret_value['truck'], 0.6666667) | ||
|
||
assert np.isclose(ret_value['acc'], 0.7) | ||
assert np.isclose(ret_value['acc_cls'], 0.7) | ||
assert np.isclose(ret_value['miou'], 0.547619048) |