This is the official implementation of the ECCV 2022 paper *Long-tailed Instance Segmentation using Gumbel Optimized Loss*.
Major advances have recently been made in object detection and segmentation. However, state-of-the-art methods fail to detect rare categories, resulting in a significant performance gap between rare and frequent categories. In this paper, we identify that the Sigmoid or Softmax functions used in deep detectors are a major reason for this low performance and are suboptimal for long-tailed detection and segmentation. To address this, we develop a Gumbel Optimized Loss (GOL) for long-tailed detection and segmentation. GOL aligns with the Gumbel distribution of rare classes in imbalanced datasets, since most classes in long-tailed detection have a low expected probability. The proposed GOL significantly outperforms the best state-of-the-art method by 1.1% AP. Compared to Mask R-CNN on the LVIS dataset, it boosts overall segmentation by 9.0% and detection by 8.0%, and improves detection of rare classes by 20.3%.
Gumbel activation with (M)ask R-CNN, (R)esNet, ResNe(X)t, (C)ascade Mask R-CNN and (H)ybrid Task Cascade.
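For reference, the activation used in `gumbel_cross_entropy` below is the cumulative distribution function of the standard Gumbel distribution. Compared with the sigmoid, and ignoring the clamping done in the code, the activation and the resulting binary cross-entropy for a logit $x$ and a binary label $y$ can be written as:

$$
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \eta(x) = \exp\left(-e^{-x}\right)
$$

$$
\mathcal{L}(x, y) = -y \log \eta(x) - (1 - y) \log\bigl(1 - \eta(x)\bigr) = y\, e^{-x} - (1 - y) \log\bigl(1 - \exp(-e^{-x})\bigr)
$$

At $x = 0$ the sigmoid gives $\sigma(0) = 0.5$, while the Gumbel activation gives $\eta(0) = e^{-1} \approx 0.368$, so the Gumbel activation assigns a lower probability to a zero logit.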
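```python
import torch
import torch.nn.functional as F


def gumbel_cross_entropy(pred, label, reduction):
    """Calculate the Gumbel cross-entropy loss.

    Args:
        pred (torch.Tensor): The prediction logits.
        label (torch.Tensor): One-hot encoded labels, same shape as pred.
        reduction (str): Reduction passed to F.binary_cross_entropy
            ('none', 'mean' or 'sum').

    Returns:
        torch.Tensor: The calculated loss.
    """
    # Lower bound keeps exp(exp(-pred)) finite in float32;
    # upper bound keeps pestim strictly below 1.
    pred = torch.clamp(pred, min=-4, max=10)
    # Gumbel activation (standard Gumbel CDF): exp(-exp(-pred)).
    pestim = 1 / (torch.exp(torch.exp(-pred)))
    loss = F.binary_cross_entropy(
        pestim, label.float(), reduction=reduction)
    # Cap the (possibly reduced) loss at 20.
    loss = torch.clamp(loss, min=0, max=20)
    return loss
```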
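To see what the loss does element-wise, here is a minimal pure-Python sketch (standard library only, no PyTorch) that mirrors `gumbel_cross_entropy` with `reduction='none'` for a single logit and label. The helper names `sigmoid`, `gumbel` and `gumbel_bce` are hypothetical and only for illustration; the clamp bounds -4, 10 and 20 are copied from the function above.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid, for comparison with the Gumbel activation."""
    return 1.0 / (1.0 + math.exp(-x))


def gumbel(x: float) -> float:
    """Gumbel activation: CDF of the standard Gumbel distribution, exp(-exp(-x))."""
    return math.exp(-math.exp(-x))


def gumbel_bce(logit: float, label: int,
               lo: float = -4.0, hi: float = 10.0, max_loss: float = 20.0) -> float:
    """Hypothetical per-element version of gumbel_cross_entropy (reduction='none')."""
    x = min(max(logit, lo), hi)           # same logit clamp as torch.clamp(pred, -4, 10)
    p = gumbel(x)
    loss = -math.log(p) if label == 1 else -math.log(1.0 - p)
    return min(max(loss, 0.0), max_loss)  # same loss clamp as torch.clamp(loss, 0, 20)


if __name__ == "__main__":
    for x in (-2.0, 0.0, 2.0):
        print(f"x={x:+.1f}  sigmoid={sigmoid(x):.3f}  gumbel={gumbel(x):.3f}  "
              f"loss(y=1)={gumbel_bce(x, 1):.3f}  loss(y=0)={gumbel_bce(x, 0):.3f}")
```

For a positive label the unclamped loss reduces to `exp(-x)`, so it grows exponentially for very negative logits; the `max=20` clamp bounds that growth.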
Requirements:
- python==3.8.12
- torch==1.7.1
- torchvision==0.8.2
- mmdet==2.21.0
- lvis
- Tested on CUDA 10.2 and RHEL 8
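A small, hypothetical helper to compare an existing environment against the pins above, using only the standard library (`importlib.metadata`, available from Python 3.8). Local build suffixes such as `+cu102` are ignored. The names `PINNED`, `PINNED_PYTHON` and `check_versions` are illustrative and not part of this repo.

```python
import sys
from importlib import metadata

# Pins copied from the requirements list above.
PINNED = {"torch": "1.7.1", "torchvision": "0.8.2", "mmdet": "2.21.0"}
PINNED_PYTHON = (3, 8, 12)


def check_versions(pinned=PINNED, lookup=metadata.version):
    """Return {package: (expected, installed or None)} for every mismatch."""
    mismatches = {}
    for pkg, expected in pinned.items():
        try:
            installed = lookup(pkg)
        except metadata.PackageNotFoundError:
            installed = None
        # Drop local version labels such as "+cu102" before comparing.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[pkg] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    if sys.version_info[:3] != PINNED_PYTHON:
        print("python:", sys.version.split()[0],
              "expected", ".".join(map(str, PINNED_PYTHON)))
    for pkg, (want, have) in check_versions().items():
        print(f"{pkg}: expected {want}, found {have}")
```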
- Create a conda environment and activate it:
```
conda create --name mmdet pytorch=1.7.1 -y
conda activate mmdet
```
- Install dependency packages:
```
conda install torchvision=0.8.2 -y
conda install pandas scipy -y
conda install opencv -y
```
- Install MMDetection:
```
pip install openmim
mim install mmdet==2.21.0
```
- Clone this repo:
```
git clone https://github.com/kostas1515/GOL.git
cd GOL
```
- Create a data directory, download the COCO 2017 datasets from https://cocodataset.org/#download (2017 Train images [118K/18GB], 2017 Val images [5K/1GB], 2017 Train/Val annotations [241MB]) and extract the zip files:
```
mkdir data
cd data
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
unzip train2017.zip
unzip val2017.zip
# download and unzip LVIS annotations
wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip
wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip
unzip lvis_v1_train.json.zip
unzip lvis_v1_val.json.zip
```
- Modify `mmdetection/configs/_base_/datasets/lvis_v1_instance.py` and make sure the `data_root` variable points to the above data directory, e.g., `data_root = '<user_path>'`.
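Before training, a quick check that the data directory contains what the download steps produce can save a failed launch. This is a minimal, hypothetical sketch: the expected layout (`train2017/`, `val2017/` and the two LVIS JSON files directly under `data/`) is an assumption based on the commands above, so adjust `EXPECTED_ENTRIES` if your `lvis_v1_instance.py` expects the annotation files elsewhere (for example in an `annotations/` subfolder).

```python
from pathlib import Path

# Assumed contents of data/ after downloading and extracting the archives.
# This layout is an assumption; edit it to match your dataset config.
EXPECTED_ENTRIES = (
    "train2017",
    "val2017",
    "lvis_v1_train.json",
    "lvis_v1_val.json",
)


def missing_entries(data_root, expected=EXPECTED_ENTRIES):
    """Return the expected files/folders that do not exist under data_root."""
    root = Path(data_root).expanduser()
    return [name for name in expected if not (root / name).exists()]


if __name__ == "__main__":
    import sys

    root = sys.argv[1] if len(sys.argv) > 1 else "data"
    missing = missing_entries(root)
    print("all present" if not missing else f"missing in {root}: {missing}")
```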
To train a model, run:
```
./tools/dist_train.sh ./configs/<experiment>/<variant.py> <#GPUs>
```
For example, to train GOL on 4 GPUs:
```
./tools/dist_train.sh ./configs/gol/droploss_normed_mask_r50_rfs_4x4_2x_gumbel.py 4
```
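The config names appear to follow MMDetection's naming convention, in which `4x4` denotes 4 GPUs x 4 images per GPU and `1x`/`2x` denote 12/24-epoch schedules. The hypothetical helpers below decode those fields so you can see how the `<#GPUs>` argument changes the total batch size. They only parse the file name; the authoritative values are the ones inside the config itself (e.g. `samples_per_gpu` and the runner's `max_epochs`).

```python
import os
import re


def decode_config_name(path):
    """Read GPUs, images per GPU and epochs from an MMDetection-style config name.

    Assumes "<g>x<n>" means g GPUs x n images per GPU and "<k>x" means a
    k x 12-epoch schedule. Fields that cannot be parsed are returned as None.
    """
    name = re.sub(r"\.py$", "", os.path.basename(path))
    info = {"gpus": None, "imgs_per_gpu": None, "epochs": None}
    batch = re.search(r"_(\d+)x(\d+)(?=_|$)", name)
    if batch:
        info["gpus"], info["imgs_per_gpu"] = int(batch.group(1)), int(batch.group(2))
    schedule = re.search(r"_(\d+)x(?=_|$)", name)
    if schedule:
        info["epochs"] = 12 * int(schedule.group(1))
    return info


def total_batch_size(path, num_gpus):
    """Images per iteration when launching the config on num_gpus GPUs."""
    per_gpu = decode_config_name(path)["imgs_per_gpu"]
    return None if per_gpu is None else num_gpus * per_gpu


if __name__ == "__main__":
    cfg = "./configs/gol/droploss_normed_mask_r50_rfs_4x4_2x_gumbel.py"
    print(decode_config_name(cfg), total_batch_size(cfg, num_gpus=2))
```

MMDetection's documentation recommends scaling the learning rate linearly with the total batch size, so launching a `4x4` config on fewer than 4 GPUs may call for a proportionally smaller learning rate.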
To test GOL:
```
./tools/dist_test.sh ./experiments/droploss_normed_mask_r50_rfs_4x4_2x_gumbel/droploss_normed_mask_r50_rfs_4x4_2x_gumbel.py ./experiments/droploss_normed_mask_r50_rfs_4x4_2x_gumbel/latest.pth 4 --eval bbox segm
```
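To compare the Sigmoid, Softmax and Gumbel activations, train the following configs:
```
# Sigmoid
./tools/dist_train.sh ./configs/activations/r50_4x4_1x.py <#GPUs>
# Softmax
./tools/dist_train.sh ./configs/activations/r50_4x4_1x_softmax.py <#GPUs>
# Gumbel
./tools/dist_train.sh ./configs/activations/gumbel/gumbel_r50_4x4_1x.py <#GPUs>
```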
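This should give results similar to the following (AP is mask AP; APr, APc and APf are the mask AP on rare, common and frequent classes; APb is box AP):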
Method | AP | APr | APc | APf | APb |
---|---|---|---|---|---|
Sigmoid | 16.4 | 0.8 | 12.7 | 27.3 | 17.2 |
Softmax | 15.2 | 0.0 | 10.6 | 26.9 | 16.1 |
Gumbel | 19.0 | 4.9 | 16.8 | 27.6 | 19.1 |
Results and pretrained models (the `v0.5` and `v1` suffixes refer to the LVIS version):

Method | AP | APr | APc | APf | APb | Model | Output |
---|---|---|---|---|---|---|---|
GOL_r50_v0.5 | 29.5 | 22.5 | 31.3 | 30.1 | 28.2 | weights | log, config |
GOL_r50_v1 | 27.7 | 21.4 | 27.7 | 30.4 | 27.5 | weights | log, config |
GOL_r101_v1 | 29.0 | 22.8 | 29.0 | 31.7 | 29.2 | weights | log, config |
If you find this work useful, please cite:
```
@inproceedings{alexandridis2022long,
  title={Long-tailed Instance Segmentation using Gumbel Optimized Loss},
  author={Alexandridis, Konstantinos Panagiotis and Deng, Jiankang and Nguyen, Anh and Luo, Shan},
  booktitle={European Conference on Computer Vision},
  pages={353--369},
  year={2022},
  organization={Springer}
}
```