Unsupervised Point Cloud Object Co-segmentation by Co-contrastive Learning and Mutual Attention Sampling
This repository is the implementation of the ICCV 2021 (Oral) paper: Unsupervised Point Cloud Object Co-segmentation by Co-contrastive Learning and Mutual Attention Sampling.
This paper presents a new task, point cloud object co-segmentation, which aims to segment the common 3D objects in a set of point clouds. We formulate this task as an object point sampling problem and develop two techniques to enable it: the mutual attention module and co-contrastive learning. The proposed method employs two point samplers based on deep neural networks, the object sampler and the background sampler. The former samples points of the common objects, while the latter samples the remaining points. The mutual attention module explores point-wise correlation across point clouds. It is embedded in both samplers and distinguishes points with strong cross-cloud correlation from the rest. After extracting features for the points selected by the two samplers, we optimize the networks with the proposed co-contrastive loss, which minimizes the feature discrepancy among the estimated object points while maximizing the feature separation between the estimated object and background points. Our method works on point clouds of an arbitrary object class. It is end-to-end trainable and does not need point-level annotations. It is evaluated on the ScanObjectNN and S3DIS datasets and achieves promising results.
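The two mechanisms described above can be illustrated with a minimal, framework-free sketch. It is a simplification under stated assumptions, not the repository's implementation: the hypothetical `cross_cloud_scores` stands in for the mutual attention module by scoring each point of one cloud with scaled dot-product attention over the points of another cloud, the hypothetical `select_top_k` stands in for the learned object sampler, and the hypothetical `co_contrastive_loss` uses an assumed margin-based form (pull object features of different clouds together, push object and background features at least `margin` apart) rather than the paper's exact formula.

```python
import math

# Hypothetical, framework-free sketch; NOT the repository's implementation.


def dot(u, v):
    """Inner product of two equal-length feature vectors."""
    return sum(a * b for a, b in zip(u, v))


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_cloud_scores(feats_a, feats_b):
    """Stand-in for mutual attention: score each point of cloud A by its
    attention-weighted similarity to the points of cloud B."""
    scale = math.sqrt(len(feats_a[0]))
    scores = []
    for fa in feats_a:
        sims = [dot(fa, fb) / scale for fb in feats_b]  # scaled dot products
        weights = softmax(sims)  # attention of this A-point over all B-points
        scores.append(sum(w * s for w, s in zip(weights, sims)))
    return scores


def select_top_k(scores, k):
    """Stand-in for the object sampler: indices of the k highest scores."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def sq_dist(u, v):
    """Squared Euclidean distance."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def co_contrastive_loss(obj_feats, bg_feats, margin=1.0):
    """Hypothetical margin-based co-contrastive loss.

    obj_feats[i] / bg_feats[i]: pooled feature of the points chosen by the
    object / background sampler in point cloud i (needs at least 2 clouds).
    """
    n = len(obj_feats)
    # Pull term: object features from different clouds should agree.
    pull = [sq_dist(obj_feats[i], obj_feats[j])
            for i in range(n) for j in range(n) if i != j]
    # Push term: object and background features should be >= margin apart.
    push = [max(0.0, margin - math.sqrt(sq_dist(o, b))) ** 2
            for o in obj_feats for b in bg_feats]
    return sum(pull) / len(pull) + sum(push) / len(push)
```

In the actual method the samplers and feature extractor are networks trained end to end; here they are fixed functions so that the selection logic and the pull/push structure of the loss can be inspected in isolation.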
Figure (GIF): Input point clouds (top row) and Output co-segmentation results (bottom row).
We strongly recommend using the Docker image provided by SampleNet (Lang et al., CVPR 2020).
To install requirements:
pip install -r requirements.txt
Download ScanObjectNN here and S3DIS here. Following JSIS3D, first process the raw S3DIS data.
Then run the preprocessing script to generate the S3DIS object dataset:
python data_preprocess/parse_data.py
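As a rough illustration of what building an object-level dataset from S3DIS involves, the sketch below assumes the raw S3DIS layout, in which each `Annotations/<class>_<index>.txt` file holds one object instance as `x y z r g b` lines. It extracts the class name, parses the coordinates, and normalizes the object into the unit sphere, a common step before feeding objects to a classifier such as one pre-trained on ModelNet40. The helper names are hypothetical, and the repository's `data_preprocess/parse_data.py` (which works on JSIS3D-processed data) may do this differently.

```python
import math
import os

# Hypothetical helpers; data_preprocess/parse_data.py may work differently.


def category_from_filename(path):
    """'.../Annotations/chair_1.txt' -> 'chair' (raw S3DIS naming scheme)."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return stem.rsplit("_", 1)[0]


def parse_xyz(lines):
    """Parse 'x y z [r g b ...]' lines, keeping only the xyz coordinates."""
    points = []
    for line in lines:
        parts = line.split()
        if len(parts) >= 3:  # skip blank or malformed lines
            points.append(tuple(float(v) for v in parts[:3]))
    return points


def normalize_unit_sphere(points):
    """Center the object at the origin and scale it into the unit sphere."""
    n = len(points)
    centroid = [sum(p[k] for p in points) / n for k in range(3)]
    centered = [tuple(p[k] - centroid[k] for k in range(3)) for p in points]
    radius = max(math.sqrt(sum(c * c for c in p)) for p in centered)
    if radius == 0:  # degenerate object: all points coincide
        return centered
    return [tuple(c / radius for c in p) for p in centered]
```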
We use a classification model pre-trained on ModelNet40 as the feature extractor. You can train it yourself by running pretrain.py,
or download the pre-trained weights here.
To train the model in the paper, run the following command, where --obj
(from 1 to 14) selects one object category in ScanObjectNN:
python train.py --config=configs/scanobj.yaml --obj=1
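Since `--obj` selects a single category, covering all of them takes one training run per category. The sketch below, assuming it is saved in the repository root under a hypothetical name such as `train_all.py`, builds the 14 commands with the `--config` and `--obj` flags shown above and by default only prints them; passing `--run` (a flag of this hypothetical helper, not of `train.py`) launches the runs one after another.

```python
import subprocess
import sys

# Hypothetical helper script (e.g. train_all.py); not part of the repository.


def training_commands(config="configs/scanobj.yaml", categories=range(1, 15)):
    """One train.py invocation per ScanObjectNN category index (1..14)."""
    return [
        [sys.executable, "train.py", f"--config={config}", f"--obj={obj}"]
        for obj in categories
    ]


def run_all(dry_run=True):
    """Print the commands, or run them sequentially if dry_run is False."""
    for cmd in training_commands():
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)  # stop at the first failed run


if __name__ == "__main__":
    run_all(dry_run="--run" not in sys.argv[1:])
```

Running the categories sequentially keeps GPU memory use at a single run; launching them in parallel would need one device (or enough memory) per run.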
To run inference with a trained model:
python test.py --config=work_dirs/raw/scanobj/chair.yaml
To generate the GIF shown in this README, run the command below. Note that this is only supported with Open3D on a machine with a local monitor.
python visualize.py
If you find our work useful in your research, please consider citing our paper:
@inproceedings{yang2021unsupervised,
title={Unsupervised Point Cloud Object Co-segmentation by Co-contrastive Learning and Mutual Attention Sampling},
author={Yang, Cheng-Kun and Chuang, Yung-Yu and Lin, Yen-Yu},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={7335--7344},
year={2021}
}