Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support ScanNet semantic segmentation dataset #390

Merged
merged 43 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
85c15a4
remove max_num_point in ScanNet data preprocessing
Wuziyi616 Mar 23, 2021
5d9aca6
add config file for ScanNet semantic segmentation dataset
Wuziyi616 Mar 24, 2021
f4bbb00
modify NormalizePointsColor in pipeline
Wuziyi616 Mar 24, 2021
2adea6a
add visualization function for semantic segmentation
Wuziyi616 Mar 25, 2021
4412691
add ignore_index to semantic segmentation visualization function
Wuziyi616 Mar 25, 2021
d8ecd43
add ignore_index to semantic segmentation evaluation function
Wuziyi616 Mar 25, 2021
3490ca6
fix ignore_index bug in semantic segmentation evaluation function
Wuziyi616 Mar 25, 2021
fea1b62
add test function to check ignore_index assignment in PointSegClassMa…
Wuziyi616 Mar 25, 2021
dd7df61
fix slicing bug in BasePoints class and add unittest
Wuziyi616 Mar 26, 2021
e721eb1
add IndoorPatchPointSample class for indoor semantic segmentation dat…
Wuziyi616 Mar 26, 2021
943492b
modify LoadPointsFromFile class and its unittest to support point col…
Wuziyi616 Mar 26, 2021
b680f71
fix data path in unittest
Wuziyi616 Mar 26, 2021
c4afd15
add setter function for coord and attributes of BasePoint and modify …
Wuziyi616 Mar 26, 2021
0bdb48a
modify color normalization operation to work on BasePoint class
Wuziyi616 Mar 26, 2021
a44657b
add unittest for ScanNet semantic segmentation data loading pipeline
Wuziyi616 Mar 26, 2021
64e9584
fix ignore_index bug in seg_eval function
Wuziyi616 Mar 26, 2021
59d0f0b
add ScanNet semantic segmentation dataset and unittest
Wuziyi616 Mar 26, 2021
070cb83
modify config file for ScanNet semantic segmentation
Wuziyi616 Mar 27, 2021
b81aaef
fix visualization function and modify unittest
Wuziyi616 Mar 29, 2021
b666ef2
fix a typo in seg_eval.py
Wuziyi616 Mar 29, 2021
206a130
raise exception when semantic mask is not provided in train/eval data…
Wuziyi616 Mar 29, 2021
230e2f1
support custom computation of label weight for loss calculation
Wuziyi616 Mar 29, 2021
3e71b5f
modify seg_eval function to be more efficient
Wuziyi616 Mar 29, 2021
417101c
fix small bugs & change variable names for clarity & add more cases t…
Wuziyi616 Mar 30, 2021
1705ada
move room index resampling and label weight computation to data pre-p…
Wuziyi616 Mar 30, 2021
f9b975f
add option allowing user to determine whether to sub-sample point clouds
Wuziyi616 Mar 30, 2021
46e0133
fix typos & change .format to f-string & fix link in comment
Wuziyi616 Mar 30, 2021
deda71d
save all visualizations into .obj format for consistency
Wuziyi616 Mar 30, 2021
ea670db
infer num_classes from label2cat in eval_seg function
Wuziyi616 Mar 30, 2021
5ae21c5
add pre-computed room index and label weight for ScanNet dataset
Wuziyi616 Mar 30, 2021
b862eba
replace .ply with .obj in unittests and documents
Wuziyi616 Mar 30, 2021
3591efa
add TODO in case data is on ceph
Wuziyi616 Mar 30, 2021
d11dc79
add base dataset for all semantic segmentation tasks & add ScanNet da…
Wuziyi616 Mar 31, 2021
e0604af
rename class for consistency
Wuziyi616 Mar 31, 2021
996afd8
fix minor typos in comment
Wuziyi616 Apr 1, 2021
8940bbf
move Custom3DSegDataset to a new file
Wuziyi616 Apr 6, 2021
35bfb35
modify BasePoint setter function to enable attribute adding
Wuziyi616 Apr 6, 2021
c4c148d
add unittest for NormalizePointsColor and fix small bugs
Wuziyi616 Apr 6, 2021
bd06d3f
fix unittest for BasePoints
Wuziyi616 Apr 6, 2021
253bf53
modify ScanNet data pre-processing scripts
Wuziyi616 Apr 6, 2021
460ed04
change ignore_idx to -1 in seg_eval function
Wuziyi616 Apr 6, 2021
71ff97b
remove sliding inference from PatchSample function and modify unittest
Wuziyi616 Apr 6, 2021
dd0f856
remove PatchSample from scannet seg test_pipeline
Wuziyi616 Apr 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions configs/_base_/datasets/scannet_seg-3d-20class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# dataset settings
dataset_type = 'ScanNetSegDataset'
data_root = './data/scannet/'
class_names = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
'curtain', 'refrigerator', 'showercurtrain', 'toilet', 'sink',
'bathtub', 'otherfurniture')
num_points = 8192
train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(
Wuziyi616 marked this conversation as resolved.
Show resolved Hide resolved
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
33, 34, 36, 39)),
dict(
type='IndoorPatchPointSample',
num_points=num_points,
block_size=1.5,
sample_rate=1.0,
ignore_index=len(class_names),
use_normalized_coord=True),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='NormalizePointsColor', color_mean=None),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points'])
]

data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'scannet_infos_train.pkl',
pipeline=train_pipeline,
classes=class_names,
test_mode=False,
ignore_index=len(class_names),
scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy',
label_weight=data_root + 'seg_info/train_label_weight.npy'),
val=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'scannet_infos_val.pkl',
pipeline=test_pipeline,
classes=class_names,
test_mode=True,
ignore_index=len(class_names)),
test=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'scannet_infos_val.pkl',
pipeline=test_pipeline,
classes=class_names,
test_mode=True,
ignore_index=len(class_names)))
9 changes: 7 additions & 2 deletions data/scannet/README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
### Prepare ScanNet Data
### Prepare ScanNet Data for Indoor Detection or Segmentation Task
We follow the procedure in [votenet](https://github.com/facebookresearch/votenet/).

1. Download ScanNet v2 data [HERE](https://github.com/ScanNet/ScanNet). Link or move the 'scans' folder to this level of directory.

2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`.
2. In this directory, extract point clouds and annotations by running `python batch_load_scannet_data.py`. Add the `--max_num_point 50000` flag if you only use the ScanNet data for the detection task. It will downsample the scenes to less points.

3. Enter the project root directory, generate training data by running
```bash
Expand Down Expand Up @@ -33,6 +33,11 @@ scannet
│ ├── xxxxx.bin
├── semantic_mask
│ ├── xxxxx.bin
├── seg_info
│ ├── train_label_weight.npy
│ ├── train_resampled_scene_idxs.npy
│ ├── val_label_weight.npy
│ ├── val_resampled_scene_idxs.npy
├── scannet_infos_train.pkl
├── scannet_infos_val.pkl

Expand Down
15 changes: 8 additions & 7 deletions data/scannet/batch_load_scannet_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ def export_one_scan(scan_name, output_filename_prefix, max_num_point,
instance_bboxes = instance_bboxes[bbox_mask, :]
print(f'Num of care instances: {instance_bboxes.shape[0]}')

N = mesh_vertices.shape[0]
if N > max_num_point:
choices = np.random.choice(N, max_num_point, replace=False)
mesh_vertices = mesh_vertices[choices, :]
semantic_labels = semantic_labels[choices]
instance_labels = instance_labels[choices]
if max_num_point is not None:
N = mesh_vertices.shape[0]
if N > max_num_point:
choices = np.random.choice(N, max_num_point, replace=False)
mesh_vertices = mesh_vertices[choices, :]
semantic_labels = semantic_labels[choices]
instance_labels = instance_labels[choices]

np.save(f'{output_filename_prefix}_vert.npy', mesh_vertices)
np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels)
Expand Down Expand Up @@ -88,7 +89,7 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--max_num_point',
default=50000,
default=None,
help='The maximum number of the points.')
Wuziyi616 marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument(
'--output_folder',
Expand Down
2 changes: 1 addition & 1 deletion docs/1_exist_data_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Optional arguments:
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset. Typically we default to use official metrics for evaluation on different datasets, so it can be simply set to `mAP` as a placeholder, which applies to nuScenes, Lyft, ScanNet and SUNRGBD. For KITTI, if we only want to evaluate the 2D detection performance, we can simply set the metric to `img_bbox` (unstable, stay tuned). For Waymo, we provide both KITTI-style evaluation (unstable) and Waymo-style official protocol, corresponding to metric `kitti` and `waymo` respectively. We recommend to use the default official metric for stable performance and fair comparison with other methods.
- `--show`: If specified, detection results will be plotted in the silient mode. It is only applicable to single GPU testing and used for debugging and visualization. This should be used with `--show-dir`.
- `--show-dir`: If specified, detection results will be plotted on the `***_points.obj` and `***_pred.ply` files in the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI available in your environment for using this option.
- `--show-dir`: If specified, detection results will be plotted on the `***_points.obj` and `***_pred.obj` files in the specified directory. It is only applicable to single GPU testing and used for debugging and visualization. You do NOT need a GUI available in your environment for using this option.

Examples:

Expand Down
6 changes: 3 additions & 3 deletions docs/useful_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ To see the SUNRGBD, ScanNet or KITTI points and detection results, you can run t
python tools/test.py ${CONFIG_FILE} ${CKPT_PATH} --show --show-dir ${SHOW_DIR}
```

Aftering running this command, plotted results **_\_points.obj and _**\_pred.ply files in `${SHOW_DIR}`.
Aftering running this command, plotted results **_\_points.obj and _**\_pred.obj files in `${SHOW_DIR}`.

To see the points, detection results and ground truth of SUNRGBD, ScanNet or KITTI during evaluation time, you can run the following command

```bash
python tools/test.py ${CONFIG_FILE} ${CKPT_PATH} --eval 'mAP' --options 'show=True' 'out_dir=${SHOW_DIR}'
```

After running this command, you will obtain **_\_points.obj, _**\_pred.ply files and \*\*\*\_gt.ply in `${SHOW_DIR}`. When `show` is enabled, [Open3D](http://www.open3d.org/) will be used to visualize the results online. You need to set `show=False` while running test in remote server withou GUI.
After running this command, you will obtain **_\_points.obj, _**\_pred.obj files and \*\*\*\_gt.obj in `${SHOW_DIR}`. When `show` is enabled, [Open3D](http://www.open3d.org/) will be used to visualize the results online. You need to set `show=False` while running test in remote server withou GUI.

As for offline visualization, you will have two options.
To visualize the results with `Open3D` backend, you can run the following command
Expand All @@ -76,7 +76,7 @@ python tools/misc/visualize_results.py ${CONFIG_FILE} --result ${RESULTS_PATH} -

![Open3D_visualization](../resources/open3d_visual.gif)

Or you can use 3D visualization software such as the [MeshLab](http://www.meshlab.net/) to open the these files under `${SHOW_DIR}` to see the 3D detection output. Specifically, open `***_points.obj` to see the input point cloud and open `***_pred.ply` to see the predicted 3D bounding boxes. This allows the inference and results generation be done in remote server and the users can open them on their host with GUI.
Or you can use 3D visualization software such as the [MeshLab](http://www.meshlab.net/) to open the these files under `${SHOW_DIR}` to see the 3D detection output. Specifically, open `***_points.obj` to see the input point cloud and open `***_pred.obj` to see the predicted 3D bounding boxes. This allows the inference and results generation be done in remote server and the users can open them on their host with GUI.

**Notice**: The visualization API is a little unstable since we plan to refactor these parts together with MMDetection in the future.

Expand Down
35 changes: 22 additions & 13 deletions mmdet3d/core/evaluation/seg_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,28 +66,37 @@ def get_acc_cls(hist):
return np.nanmean(np.diag(hist) / hist.sum(axis=1))


def seg_eval(gt_labels, seg_preds, label2cat, logger=None):
"""Semantic Segmentation Evaluation.
def seg_eval(gt_labels, seg_preds, label2cat, ignore_index, logger=None):
"""Semantic Segmentation Evaluation.

Evaluate the result of the Semantic Segmentation.
Evaluate the result of the Semantic Segmentation.

Args:
gt_labels (list[torch.Tensor]): Ground truth labels.
seg_preds (list[torch.Tensor]): Predtictions
label2cat (dict): Map from label to category.
logger (logging.Logger | str | None): The way to print the mAP
Args:
gt_labels (list[torch.Tensor]): Ground truth labels.
seg_preds (list[torch.Tensor]): Predictions.
label2cat (dict): Map from label to category name.
ignore_index (int): Index that will be ignored in evaluation.
logger (logging.Logger | str | None): The way to print the mAP
summary. See `mmdet.utils.print_log()` for details. Default: None.

Return:
Returns:
dict[str, float]: Dict of results.
"""
assert len(seg_preds) == len(gt_labels)
num_classes = len(label2cat)

hist_list = []
for i in range(len(seg_preds)):
hist_list.append(
fast_hist(seg_preds[i].numpy().astype(int),
gt_labels[i].numpy().astype(int), len(label2cat)))
for i in range(len(gt_labels)):
gt_seg = gt_labels[i].clone().numpy().astype(np.int)
pred_seg = seg_preds[i].clone().numpy().astype(np.int)

# filter out ignored points
pred_seg[gt_seg == ignore_index] = -1
gt_seg[gt_seg == ignore_index] = -1

# calculate one instance result
hist_list.append(fast_hist(pred_seg, gt_seg, num_classes))

iou = per_class_iou(sum(hist_list))
miou = np.nanmean(iou)
acc = get_acc(sum(hist_list))
Expand Down
71 changes: 64 additions & 7 deletions mmdet3d/core/points/base_points.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch
import warnings
from abc import abstractmethod


Expand Down Expand Up @@ -46,6 +47,17 @@ def coord(self):
"""torch.Tensor: Coordinates of each point with size (N, 3)."""
return self.tensor[:, :3]

@coord.setter
def coord(self, tensor):
"""Set the coordinates of each point."""
try:
tensor = tensor.reshape(self.shape[0], 3)
except (RuntimeError, ValueError): # for torch.Tensor and np.ndarray
raise ValueError(f'got unexpected shape {tensor.shape}')
if not isinstance(tensor, torch.Tensor):
tensor = self.tensor.new_tensor(tensor)
self.tensor[:, :3] = tensor

@property
def height(self):
"""torch.Tensor: A vector with height of each point."""
Expand All @@ -55,6 +67,27 @@ def height(self):
else:
return None

@height.setter
def height(self, tensor):
"""Set the height of each point."""
try:
tensor = tensor.reshape(self.shape[0])
except (RuntimeError, ValueError): # for torch.Tensor and np.ndarray
raise ValueError(f'got unexpected shape {tensor.shape}')
if not isinstance(tensor, torch.Tensor):
tensor = self.tensor.new_tensor(tensor)
if self.attribute_dims is not None and \
'height' in self.attribute_dims.keys():
self.tensor[:, self.attribute_dims['height']] = tensor
else:
# add height attribute
if self.attribute_dims is None:
self.attribute_dims = dict()
attr_dim = self.shape[1]
self.tensor = torch.cat([self.tensor, tensor.unsqueeze(1)], dim=1)
self.attribute_dims.update(dict(height=attr_dim))
self.points_dim += 1

@property
def color(self):
"""torch.Tensor: A vector with color of each point."""
Expand All @@ -64,6 +97,30 @@ def color(self):
else:
return None

@color.setter
def color(self, tensor):
"""Set the color of each point."""
try:
tensor = tensor.reshape(self.shape[0], 3)
except (RuntimeError, ValueError): # for torch.Tensor and np.ndarray
raise ValueError(f'got unexpected shape {tensor.shape}')
if tensor.max() >= 256 or tensor.min() < 0:
warnings.warn('point got color value beyond [0, 255]')
if not isinstance(tensor, torch.Tensor):
tensor = self.tensor.new_tensor(tensor)
if self.attribute_dims is not None and \
'color' in self.attribute_dims.keys():
self.tensor[:, self.attribute_dims['color']] = tensor
else:
# add color attribute
if self.attribute_dims is None:
self.attribute_dims = dict()
attr_dim = self.shape[1]
self.tensor = torch.cat([self.tensor, tensor], dim=1)
self.attribute_dims.update(
dict(color=[attr_dim, attr_dim + 1, attr_dim + 2]))
self.points_dim += 3

@property
def shape(self):
"""torch.Shape: Shape of points."""
Expand Down Expand Up @@ -136,8 +193,8 @@ def translate(self, trans_vector):
trans_vector.shape[1] == 3
else:
raise NotImplementedError(
'Unsupported translation vector of shape {}'.format(
trans_vector.shape))
f'Unsupported translation vector of shape {trans_vector.shape}'
)
self.tensor[:, :3] += trans_vector

def in_range_3d(self, point_range):
Expand Down Expand Up @@ -233,8 +290,8 @@ def __getitem__(self, item):
elif isinstance(item, tuple) and len(item) == 2:
if isinstance(item[1], slice):
start = 0 if item[1].start is None else item[1].start
stop = self.tensor.shape[1] + \
1 if item[1].stop is None else item[1].stop
stop = self.tensor.shape[1] if \
item[1].stop is None else item[1].stop
step = 1 if item[1].step is None else item[1].step
item = list(item)
item[1] = list(range(start, stop, step))
Expand All @@ -246,9 +303,9 @@ def __getitem__(self, item):
if self.attribute_dims is not None:
attribute_dims = self.attribute_dims.copy()
for key in self.attribute_dims.keys():
cur_attribute_dim = attribute_dims[key]
if isinstance(cur_attribute_dim, int):
cur_attribute_dims = [cur_attribute_dim]
cur_attribute_dims = attribute_dims[key]
if isinstance(cur_attribute_dims, int):
cur_attribute_dims = [cur_attribute_dims]
intersect_attr = list(
set(cur_attribute_dims).intersection(set(keep_dims)))
if len(intersect_attr) == 1:
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/core/visualizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .show_result import show_result
from .show_result import show_result, show_seg_result

__all__ = ['show_result']
__all__ = ['show_result', 'show_seg_result']
Loading