Skip to content

Commit

Permalink
[Enhance] Remove useless param in 3D Seg Dataset classes (#607)
Browse files Browse the repository at this point in the history
* remove useless label_weight args in dataset class

* modify unit tests

* update compatibility docs
  • Loading branch information
Wuziyi616 authored Jun 1, 2021
1 parent d5bddd2 commit 6c5a320
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 140 deletions.
4 changes: 0 additions & 4 deletions configs/_base_/datasets/s3dis_seg-3d-13class.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@
scene_idxs=[
data_root + f'seg_info/Area_{i}_resampled_scene_idxs.npy'
for i in train_area
],
label_weights=[
data_root + f'seg_info/Area_{i}_label_weight.npy'
for i in train_area
]),
val=dict(
type=dataset_type,
Expand Down
3 changes: 1 addition & 2 deletions configs/_base_/datasets/scannet_seg-3d-20class.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@
classes=class_names,
test_mode=False,
ignore_index=len(class_names),
scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy',
label_weight=data_root + 'seg_info/train_label_weight.npy'),
scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy'),
val=dict(
type=dataset_type,
data_root=data_root,
Expand Down
6 changes: 6 additions & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

This document provides detailed descriptions of the BC-breaking changes in MMDetection3D.

## MMDetection3D 0.15.0

### Dataset class for 3D segmentation task

We remove a useless parameter `label_weight` from segmentation datasets including `Custom3DSegDataset`, `ScanNetSegDataset` and `S3DISSegDataset` since this weight is utilized in the loss function of model class. Please modify the code as well as the config files accordingly if you use or inherit from these codes.

## MMDetection3D 0.14.0

### ScanNet data pre-processing
Expand Down
37 changes: 7 additions & 30 deletions mmdet3d/datasets/custom_3d_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ class Custom3DSegDataset(Dataset):
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
label_weight (np.ndarray | str, optional): Precomputed weight to \
balance loss calculation. If None is given, use equal weighting.
Defaults to None.
"""
# names of all classes data used for the task
CLASSES = None
Expand All @@ -63,8 +60,7 @@ def __init__(self,
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=None,
label_weight=None):
scene_idxs=None):
super().__init__()
self.data_root = data_root
self.ann_file = ann_file
Expand All @@ -79,8 +75,7 @@ def __init__(self,
self.ignore_index = len(self.CLASSES) if \
ignore_index is None else ignore_index

self.scene_idxs, self.label_weight = \
self.get_scene_idxs_and_label_weight(scene_idxs, label_weight)
self.scene_idxs = self.get_scene_idxs(scene_idxs)
self.CLASSES, self.PALETTE = \
self.get_classes_and_palette(classes, palette)

Expand Down Expand Up @@ -250,26 +245,16 @@ def get_classes_and_palette(self, classes=None, palette=None):
for cls_name in class_names
]

# also need to modify self.label_weight
self.label_weight = np.array([
self.label_weight[self.CLASSES.index(cls_name)]
for cls_name in class_names
]).astype(np.float32)

return class_names, palette

def get_scene_idxs_and_label_weight(self, scene_idxs, label_weight):
"""Compute scene_idxs for data sampling and label weight for loss \
calculation.
def get_scene_idxs(self, scene_idxs):
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points. Label_weight is
inversely proportional to number of class points.
We sample more times for scenes with more points.
"""
if self.test_mode:
# when testing, we load one whole scene every time
# and we don't need label weight for loss calculation
return np.arange(len(self.data_infos)).astype(np.int32), \
np.ones(len(self.CLASSES)).astype(np.float32)
return np.arange(len(self.data_infos)).astype(np.int32)

# we may need to re-sample different scenes according to scene_idxs
# this is necessary for indoor scene segmentation such as ScanNet
Expand All @@ -280,15 +265,7 @@ def get_scene_idxs_and_label_weight(self, scene_idxs, label_weight):
else:
scene_idxs = np.array(scene_idxs)

if label_weight is None:
# we don't used label weighting in training
label_weight = np.ones(len(self.CLASSES))
elif isinstance(label_weight, str):
label_weight = np.load(label_weight)
else:
label_weight = np.array(label_weight)

return scene_idxs.astype(np.int32), label_weight.astype(np.float32)
return scene_idxs.astype(np.int32)

def format_results(self,
outputs,
Expand Down
65 changes: 12 additions & 53 deletions mmdet3d/datasets/s3dis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ class _S3DISSegDataset(Custom3DSegDataset):
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
label_weight (np.ndarray | str, optional): Precomputed weight to \
balance loss calculation. If None is given, compute from data.
Defaults to None.
"""
CLASSES = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter')
Expand All @@ -66,8 +63,7 @@ def __init__(self,
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=None,
label_weight=None):
scene_idxs=None):

super().__init__(
data_root=data_root,
Expand All @@ -78,8 +74,7 @@ def __init__(self,
modality=modality,
test_mode=test_mode,
ignore_index=ignore_index,
scene_idxs=scene_idxs,
label_weight=label_weight)
scene_idxs=scene_idxs)

def get_ann_info(self, index):
"""Get annotation info according to the given index.
Expand Down Expand Up @@ -153,21 +148,17 @@ def show(self, results, out_dir, show=True, pipeline=None):
pred_sem_mask, out_dir, file_name,
np.array(self.PALETTE), self.ignore_index, show)

def get_scene_idxs_and_label_weight(self, scene_idxs, label_weight):
"""Compute scene_idxs for data sampling and label weight for loss \
calculation.
def get_scene_idxs(self, scene_idxs):
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points. Label_weight is
inversely proportional to number of class points.
We sample more times for scenes with more points.
"""
# when testing, we load one whole scene every time
# and we don't need label weight for loss calculation
if not self.test_mode and scene_idxs is None:
raise NotImplementedError(
'please provide re-sampled scene indexes for training')

return super().get_scene_idxs_and_label_weight(scene_idxs,
label_weight)
return super().get_scene_idxs(scene_idxs)


@DATASETS.register_module()
Expand All @@ -178,7 +169,7 @@ class S3DISSegDataset(_S3DISSegDataset):
This class serves as the API for experiments on the S3DIS Dataset.
It wraps the provided datasets of different areas.
We don't use `mmdet.datasets.dataset_wrappers.ConcatDataset` because we
need to concat the `scene_idxs` and `label_weights` of different areas.
need to concat the `scene_idxs` of different areas.
Please refer to the `google form <https://docs.google.com/forms/d/e/1FAIpQL
ScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1>`_ for
Expand All @@ -203,9 +194,6 @@ class S3DISSegDataset(_S3DISSegDataset):
scene_idxs (list[np.ndarray] | list[str], optional): Precomputed index
to load data. For scenes with many points, we may sample it several
times. Defaults to None.
label_weights (list[np.ndarray] | list[str], optional): Precomputed
weight to balance loss calculation. If None is given, compute from
data. Defaults to None.
"""

def __init__(self,
Expand All @@ -217,14 +205,11 @@ def __init__(self,
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=None,
label_weights=None):
scene_idxs=None):

# make sure that ann_files, scene_idxs and label_weights have same len
# make sure that ann_files and scene_idxs have same length
ann_files = self._check_ann_files(ann_files)
scene_idxs = self._check_scene_idxs(scene_idxs, len(ann_files))
label_weights = self._check_label_weights(label_weights,
len(ann_files))

# initialize some attributes as datasets[0]
super().__init__(
Expand All @@ -236,8 +221,7 @@ def __init__(self,
modality=modality,
test_mode=test_mode,
ignore_index=ignore_index,
scene_idxs=scene_idxs[0],
label_weight=label_weights[0])
scene_idxs=scene_idxs[0])

datasets = [
_S3DISSegDataset(
Expand All @@ -249,14 +233,12 @@ def __init__(self,
modality=modality,
test_mode=test_mode,
ignore_index=ignore_index,
scene_idxs=scene_idxs[i],
label_weight=label_weights[i]) for i in range(len(ann_files))
scene_idxs=scene_idxs[i]) for i in range(len(ann_files))
]

# data_infos, scene_idxs, label_weight need to be concat
# data_infos and scene_idxs need to be concat
self.concat_data_infos([dst.data_infos for dst in datasets])
self.concat_scene_idxs([dst.scene_idxs for dst in datasets])
self.concat_label_weight([dst.label_weight for dst in datasets])

# set group flag for the sampler
if not self.test_mode:
Expand Down Expand Up @@ -287,15 +269,6 @@ def concat_scene_idxs(self, scene_idxs):
[self.scene_idxs, one_scene_idxs + offset]).astype(np.int32)
offset = np.unique(self.scene_idxs).max() + 1

def concat_label_weight(self, label_weights):
"""Concat label_weight from several datasets to form self.label_weight.
Args:
label_weights (list[np.ndarray])
"""
# TODO: simply average them?
self.label_weight = np.array(label_weights).mean(0).astype(np.float32)

@staticmethod
def _duplicate_to_list(x, num):
"""Repeat x `num` times to form a list."""
Expand All @@ -321,17 +294,3 @@ def _check_scene_idxs(self, scene_idx, num):
return scene_idx
# single idx
return self._duplicate_to_list(scene_idx, num)

def _check_label_weights(self, label_weight, num):
"""Make label_weights as list/tuple."""
if label_weight is None:
return self._duplicate_to_list(label_weight, num)
# label_weight could be str, np.ndarray, list or tuple
if isinstance(label_weight, str): # str
return self._duplicate_to_list(label_weight, num)
if isinstance(label_weight[0], str): # list of str
return label_weight
if isinstance(label_weight[0], (list, tuple, np.ndarray)): # list of w
return label_weight
# single weight
return self._duplicate_to_list(label_weight, num)
21 changes: 6 additions & 15 deletions mmdet3d/datasets/scannet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,6 @@ class ScanNetSegDataset(Custom3DSegDataset):
scene_idxs (np.ndarray | str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
label_weight (np.ndarray | str, optional): Precomputed weight to \
balance loss calculation. If None is given, compute from data.
Defaults to None.
"""
CLASSES = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table',
'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
Expand Down Expand Up @@ -271,8 +268,7 @@ def __init__(self,
modality=None,
test_mode=False,
ignore_index=None,
scene_idxs=None,
label_weight=None):
scene_idxs=None):

super().__init__(
data_root=data_root,
Expand All @@ -283,8 +279,7 @@ def __init__(self,
modality=modality,
test_mode=test_mode,
ignore_index=ignore_index,
scene_idxs=scene_idxs,
label_weight=label_weight)
scene_idxs=scene_idxs)

def get_ann_info(self, index):
"""Get annotation info according to the given index.
Expand Down Expand Up @@ -358,21 +353,17 @@ def show(self, results, out_dir, show=True, pipeline=None):
pred_sem_mask, out_dir, file_name,
np.array(self.PALETTE), self.ignore_index, show)

def get_scene_idxs_and_label_weight(self, scene_idxs, label_weight):
"""Compute scene_idxs for data sampling and label weight for loss \
calculation.
def get_scene_idxs(self, scene_idxs):
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points. Label_weight is
inversely proportional to number of class points.
We sample more times for scenes with more points.
"""
# when testing, we load one whole scene every time
# and we don't need label weight for loss calculation
if not self.test_mode and scene_idxs is None:
raise NotImplementedError(
'please provide re-sampled scene indexes for training')

return super().get_scene_idxs_and_label_weight(scene_idxs,
label_weight)
return super().get_scene_idxs(scene_idxs)

def format_results(self, results, txtfile_prefix=None):
r"""Format the results to txt file. Refer to `ScanNet documentation
Expand Down
Loading

0 comments on commit 6c5a320

Please sign in to comment.