Interpretable3D: An Ad-Hoc Interpretable Classifier for 3D Point Clouds
Tuo Feng, Ruijie Quan, Xiaohan Wang, Wenguan Wang, Yi Yang
This is the official implementation of "Interpretable3D: An Ad-Hoc Interpretable Classifier for 3D Point Clouds" (Accepted at AAAI 2024).
3D decision-critical tasks urgently require research on explanations to ensure system reliability and transparency. Extensive explanatory research has been conducted on 2D images, but such research is still lacking in the 3D field. Furthermore, existing explanations for 3D models are post-hoc and can be misleading, as they separate the explanations from the original model. To address these issues, we propose an ad-hoc interpretable classifier for 3D point clouds (i.e., Interpretable3D). As an intuitive case-based classifier, Interpretable3D can provide reliable ad-hoc explanations without any embarrassing nuances. It allows users to understand how queries are embedded within past observations in prototype sets. Interpretable3D has two iterative training steps: 1) updating each prototype with the mean of the embeddings within the same sub-class in Prototype Estimation, and 2) penalizing or rewarding the estimated prototypes in Prototype Optimization. The mean of the embeddings has a clear statistical meaning, i.e., class sub-centers. Moreover, we update the prototypes with their most similar observations in the last few epochs. Finally, Interpretable3D classifies new samples according to the prototypes. We evaluate Interpretable3D on four popular point cloud models: DGCNN, PointNet2, PointMLP, and PointNeXt. Interpretable3D achieves performance comparable or superior to that of softmax-based black-box models on 3D shape classification and part segmentation.
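To make the case-based decision concrete, here is a minimal NumPy sketch of prototype-based classification and of the late-epoch step that replaces each prototype with its most similar observation. It is an illustration under stated assumptions, not the official implementation: cosine similarity between L2-normalized embeddings, K prototypes per class, and a class score equal to the similarity of the best-matching prototype are all assumptions, and `classify` and `project_to_observations` are hypothetical names.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Scale each vector along the last axis to unit length."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def classify(queries, prototypes):
    """Nearest-prototype classification (hypothetical helper).

    queries:    (N, D) query embeddings
    prototypes: (C, K, D) K prototypes for each of C classes
    Returns the predicted labels (N,) and, for each query, the index of the
    best-matching prototype inside the predicted class (N,), which acts as
    the case-based explanation.
    """
    q = l2_normalize(queries)
    p = l2_normalize(prototypes)
    sim = np.einsum("nd,ckd->nck", q, p)        # cosine similarities (N, C, K)
    class_scores = sim.max(axis=2)              # best prototype per class (N, C)
    labels = class_scores.argmax(axis=1)        # predicted class (N,)
    best_proto = sim[np.arange(len(q)), labels].argmax(axis=1)
    return labels, best_proto


def project_to_observations(prototypes, embeddings, labels):
    """Replace every prototype with its most similar training embedding of
    the same class, so that each prototype is a real past observation."""
    p = l2_normalize(prototypes)
    e = l2_normalize(embeddings)
    projected = prototypes.copy()
    for c in range(prototypes.shape[0]):
        mask = labels == c
        if not mask.any():
            continue                            # no observations: keep prototypes
        sim = p[c] @ e[mask].T                  # (K, M_c)
        projected[c] = embeddings[mask][sim.argmax(axis=1)]
    return projected
```

For example, with `protos = np.random.default_rng(0).normal(size=(3, 2, 8))`, a query equal to `protos[2, 1]` plus small noise is assigned to class 2, and its explanation is prototype 1 of that class.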
ModelNet40: Download the resampled ModelNet here and save it in data/modelnet40_normal_resampled/, or download ModelNet here and save it in data/modelnet40_ply_hdf5_2048/.
ScanObjectNN: Download ScanObjectNN here and save it in data/ScanObjectNN/main_split/.
Note: We conduct experiments on the hardest variant of ScanObjectNN (PB_T50_RS).
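Before training, it can help to confirm that the data folders sit where the scripts expect them. The following is a small sketch that only checks the directory names listed above; it assumes `data/` lives in the repository root, and `check_data_layout` is a hypothetical helper, not part of the repository. For ModelNet40, only one of the two layouts is needed.

```python
from pathlib import Path

# Directory names from the data preparation instructions; the "data" folder
# is assumed to be relative to the repository root.
EXPECTED_DIRS = {
    "ModelNet40 (resampled)": "data/modelnet40_normal_resampled",
    "ModelNet40 (HDF5)": "data/modelnet40_ply_hdf5_2048",
    "ScanObjectNN": "data/ScanObjectNN/main_split",
}


def check_data_layout(repo_root="."):
    """Return {dataset name: True if its directory exists, else False}."""
    root = Path(repo_root)
    return {name: (root / rel).is_dir() for name, rel in EXPECTED_DIRS.items()}


if __name__ == "__main__":
    for name, ok in check_data_layout().items():
        print(f"{name:<25} {'found' if ok else 'missing'}")
```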
To train Interpretable3D-M:
python train_XXXX_ip3d.py
To test Interpretable3D-M:
python test_XXXX_ip3d.py
Interpretable3D-M updates the prototypes solely with the sub-class means, i.e., the sub-class centers obtained by online clustering.
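As a rough illustration of updating prototypes with sub-class centers, here is a minimal sketch of one online-clustering step for a single class. It is a simplification under stated assumptions: each embedding is hard-assigned to its nearest sub-center, and each center moves toward the mean of its members with a hypothetical `momentum` term. The codebases acknowledged below (DNC, ProtoSeg) use more elaborate assignment schemes, and the official code may differ; `update_subcenters` is a hypothetical name.

```python
import numpy as np


def update_subcenters(embeddings, centers, momentum=0.9):
    """One online-clustering step for the embeddings of a single class.

    embeddings: (M, D) embeddings of the samples of this class
    centers:    (K, D) current sub-class centers, i.e. the class's prototypes
    momentum:   fraction of the old center that is kept (hypothetical choice)
    Returns the updated centers (K, D) and the sub-class assignment (M,).
    """
    # Hard-assign each embedding to its nearest sub-center (squared L2).
    d2 = ((embeddings[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    assign = d2.argmin(axis=1)
    new_centers = centers.copy()
    for k in range(centers.shape[0]):
        members = embeddings[assign == k]
        if len(members) == 0:
            continue  # empty sub-class: keep the previous center
        # The member mean is the statistical "class sub-center".
        new_centers[k] = momentum * centers[k] + (1 - momentum) * members.mean(axis=0)
    return new_centers, assign
```

With `momentum=0.0`, each prototype becomes exactly the mean of its current sub-class, which is the behavior described above; a nonzero momentum only smooths the update across batches.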
For installation, please refer to Pointnet_Pointnet2_pytorch.
| Model | Accuracy (%) |
|---|---|
| PointNet2_MSG (PyTorch, with normals) | 92.8 |
| Interpretable3D-M+PointNet2_MSG (PyTorch, with normals) | 93.5 |
| Interpretable3D-M+PointNet2_MSG (PyTorch, with normals) (vote) | 93.7 |
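The "(vote)" row refers to test-time voting. A minimal sketch, assuming voting means averaging the class scores of several stochastic forward passes over the same test shapes (for example, passes that differ in random point sampling); `vote_predict` and `predict_fn` are hypothetical stand-ins for the repository's test loop and trained model.

```python
import numpy as np


def vote_predict(predict_fn, points, num_votes=10, seed=0):
    """Test-time voting: average class scores over several passes.

    predict_fn: hypothetical callable (points, rng) -> (B, C) class scores;
                any randomness in a pass should be drawn from rng.
    points:     (B, N, 3) batch of test point clouds
    Returns the voted class index for each shape (B,).
    """
    rng = np.random.default_rng(seed)
    pool = sum(predict_fn(points, rng) for _ in range(num_votes))
    return (pool / num_votes).argmax(axis=1)
```

Averaging reduces the variance that comes from a single random pass, which is why the voted accuracy can be slightly higher than the single-pass accuracy.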
| Model | OA (%) | mAcc (%) |
|---|---|---|
| PointNet2_MSG (paper) | 79.1 | 77.6 |
| Interpretable3D (paper) | 79.3 | 78.4 |
| PointNet2_MSG (reprod.) | 79.9 | 77.1 |
| Interpretable3D-M+PointNet2_MSG (reprod.) | 80.0 | 77.3 |
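OA is the overall accuracy over all test samples, while mAcc averages the per-class accuracies, so rare classes weigh as much as frequent ones. The following sketch shows both metrics under their standard definitions; `overall_and_mean_accuracy` is a hypothetical helper, not the repository's evaluation code.

```python
import numpy as np


def overall_and_mean_accuracy(y_true, y_pred, num_classes):
    """Return (OA, mAcc) in percent.

    OA:   fraction of all test samples that are classified correctly.
    mAcc: accuracy computed separately for each class, then averaged.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    oa = (y_true == y_pred).mean()
    per_class = [
        (y_pred[y_true == c] == c).mean()
        for c in range(num_classes)
        if np.any(y_true == c)  # skip classes absent from the test set
    ]
    return 100 * oa, 100 * np.mean(per_class)
```

For instance, predicting class 0 for the labels `[0, 0, 0, 1]` gives an OA of 75% but an mAcc of only 50%, because the single class-1 sample is always wrong.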
More code and experimental results will be released gradually.
We find that training the softmax classifier and Interpretable3D-M simultaneously leads to faster convergence and better results.
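One way to picture such joint training is a shared backbone embedding feeding both heads, with the two cross-entropy losses summed. The following NumPy sketch computes only the loss value (a real training loop would use autograd, e.g. in PyTorch). The weight `lam`, the temperature `tau`, and the use of the best prototype similarity per class as the prototype logit are illustrative assumptions, not the repository's settings; `joint_loss` is a hypothetical name.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean cross-entropy computed from raw logits (numerically stable)."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def joint_loss(embeddings, labels, softmax_w, prototypes, lam=1.0, tau=0.1):
    """Softmax-head loss plus a weighted prototype-classifier loss.

    embeddings: (B, D) backbone features shared by both heads
    softmax_w:  (D, C) weights of a linear softmax classifier
    prototypes: (C, K, D) K prototypes per class
    lam, tau:   hypothetical loss weight and similarity temperature
    """
    softmax_logits = embeddings @ softmax_w                      # (B, C)
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    p = prototypes / np.linalg.norm(prototypes, axis=2, keepdims=True)
    # Cosine similarity to every prototype, best prototype per class: (B, C)
    proto_logits = np.einsum("bd,ckd->bck", e, p).max(axis=2) / tau
    return cross_entropy(softmax_logits, labels) + lam * cross_entropy(proto_logits, labels)
```

Because both terms back-propagate into the same backbone, the softmax head can provide a stable training signal while the prototypes are still being estimated, which is one plausible reading of the faster convergence reported above.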
If you find the code useful in your research, please consider citing our paper:
@inproceedings{feng2024interpretable3D,
title={Interpretable3D: An Ad-Hoc Interpretable Classifier for 3D Point Clouds},
author={Feng, Tuo and Quan, Ruijie and Wang, Xiaohan and Wang, Wenguan and Yang, Yi},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
year={2024}
}
For any comments, please email feng.tuo@student.uts.edu.au.
We thank the authors of the following open-source codebases: DNC, ProtoSeg, Pointnet_Pointnet2_pytorch, PointNeXt, and Cluster3DSeg.