JiaDingCN/GeminiFusion

GeminiFusion for Multimodal Segmentation on NYUDv2 & SUN RGBD Datasets (ICML 2024)

This is the official implementation of our paper "GeminiFusion: Efficient Pixel-wise Multimodal Fusion for Vision Transformer".

Authors: Ding Jia, Jianyuan Guo, Kai Han, Han Wu, Chao Zhang, Chang Xu, Xinghao Chen

Code List

We have applied our GeminiFusion to different tasks and datasets:

Introduction

We propose GeminiFusion, a pixel-wise fusion approach that capitalizes on aligned cross-modal representations. GeminiFusion combines intra-modal and inter-modal attention, dynamically integrating complementary information across modalities. We employ layer-adaptive noise to control their interplay on a per-layer basis, thereby achieving a harmonized fusion process. Notably, GeminiFusion maintains linear complexity with respect to the number of input tokens, so the multimodal framework runs with efficiency comparable to unimodal networks. Comprehensive evaluations demonstrate the superior performance of GeminiFusion against leading-edge techniques.
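To make the linear-complexity claim concrete, here is a minimal NumPy sketch of one possible reading of pixel-wise fusion. It is not the code in this repository. It assumes that each pixel of one modality attends to only two candidates, itself and the spatially aligned pixel of the other modality, so the cost grows with N rather than N². The function name `pixelwise_fusion`, the shared projection matrices, and the scalar `noise` added to the intra-modal key are all illustrative assumptions.

```python
import numpy as np

def pixelwise_fusion(x_a, x_b, w_q, w_k, w_v, noise=0.0):
    """Fuse modality b into modality a, one pixel at a time (illustrative only).

    x_a, x_b: (N, C) token arrays for two spatially aligned modalities.
    w_q, w_k, w_v: (C, C) projection matrices (shared across modalities here).
    noise: scalar added before the intra-modal key projection; a stand-in
           for the layer-adaptive noise, whose exact form is an assumption.
    """
    q = x_a @ w_q                     # (N, C) queries from modality a
    k_self = (x_a + noise) @ w_k      # (N, C) intra-modal keys
    k_cross = x_b @ w_k               # (N, C) inter-modal keys
    v_self = x_a @ w_v
    v_cross = x_b @ w_v

    scale = 1.0 / np.sqrt(q.shape[-1])
    # Each pixel scores exactly two candidates, so this is O(N), not O(N^2).
    logits = np.stack(
        [(q * k_self).sum(-1), (q * k_cross).sum(-1)], axis=-1
    ) * scale                                        # (N, 2)
    logits -= logits.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over 2 entries

    fused = weights[:, :1] * v_self + weights[:, 1:] * v_cross  # (N, C)
    return fused, weights

rng = np.random.default_rng(0)
N, C = 16, 8
x_rgb = rng.normal(size=(N, C))
x_depth = rng.normal(size=(N, C))
w_q, w_k, w_v = (rng.normal(size=(C, C)) for _ in range(3))
fused, weights = pixelwise_fusion(x_rgb, x_depth, w_q, w_k, w_v, noise=0.1)
print(fused.shape, weights.shape)  # (16, 8) (16, 2)
```

The `weights` array has one row per pixel and two columns, which shows why memory and compute scale with the number of tokens instead of its square.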

Framework

[Figure: GeminiFusion framework]

Model Zoo

NYUDv2 dataset

| Model | Backbone | mIoU | Download |
| --- | --- | --- | --- |
| GeminiFusion | MiT-B3 | 56.8 | model |
| GeminiFusion | MiT-B5 | 57.7 | model |
| GeminiFusion | swin_tiny | 52.2 | model |
| GeminiFusion | swin-small | 55.0 | model |
| GeminiFusion | swin-large-224 | 58.8 | model |
| GeminiFusion | swin-large-384 | 60.2 | model |
| GeminiFusion | swin-large-384 + FineTune from SUN 300eps | 60.9 | model |

SUN RGBD dataset

| Model | Backbone | mIoU | Download |
| --- | --- | --- | --- |
| GeminiFusion | MiT-B3 | 52.7 | model |
| GeminiFusion | MiT-B5 | 53.3 | model |
| GeminiFusion | swin_tiny | 50.2 | model |
| GeminiFusion | swin-large-384 | 54.8 | model |

Installation

We build GeminiFusion on the TokenFusion codebase, which requires no additional installation steps. If you run into problems with the framework, refer to the official TokenFusion README.

Most of the GeminiFusion-related code is located in the following files:

We also deleted config.py from the TokenFusion codebase, since it is not used here.

Data

NYUDv2 Dataset Preparation

Please follow the data preparation instructions for NYUDv2 in the TokenFusion README. By default the data path is /cache/datasets/nyudv2; you can change it with --train-dir <your data path>.

SUN RGBD Dataset Preparation

Please download the SUN RGBD dataset by following the link in DFormer. By default the data path is /cache/datasets/sunrgbd_Dformer/SUNRGBD; you can change it with --train-dir <your data path>.
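Because a wrong data path usually surfaces only after the distributed launch has started, a small pre-flight check can save time. The sketch below is a hypothetical helper and not part of this repository. The only facts it takes from this README are the two default paths and the `--train-dir` flag; the name `resolve_train_dir` is an assumption.

```python
from pathlib import Path

# Default data paths quoted in this README; override them with --train-dir.
DEFAULT_TRAIN_DIRS = {
    "nyudv2": "/cache/datasets/nyudv2",
    "sunrgbd": "/cache/datasets/sunrgbd_Dformer/SUNRGBD",
}

def resolve_train_dir(dataset, train_dir=None):
    """Return the data directory to use, failing early if it does not exist."""
    chosen = train_dir if train_dir is not None else DEFAULT_TRAIN_DIRS[dataset]
    path = Path(chosen)
    if not path.is_dir():
        raise FileNotFoundError(
            f"{dataset} data not found at {path}; pass --train-dir <your data path>"
        )
    return path
```

Calling `resolve_train_dir("nyudv2")` checks the default location, while `resolve_train_dir("nyudv2", "<your data path>")` checks a custom one before you pass it to `--train-dir`.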

Train

NYUDv2 Training

On the NYUDv2 dataset, we follow TokenFusion's setting and train GeminiFusion on 3 GPUs.

# mit-b3
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 

# mit-b5
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone mit_b5 --dataset nyudv2 -c nyudv2_mit_b5 --dpr 0.35

# swin_tiny
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone swin_tiny --dataset nyudv2 -c nyudv2_swin_tiny --dpr 0.2

# swin_small
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone swin_small --dataset nyudv2 -c nyudv2_swin_small

# swin_large
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone swin_large --dataset nyudv2 -c nyudv2_swin_large

# swin_large_window12
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone swin_large_window12 --dataset nyudv2 -c nyudv2_swin_large_window12 --dpr 0.2

# swin-large-384+FineTune from SUN 300eps
# Download swin-large-384.pth.tar from our link, or train it yourself
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone swin_large_window12 --dataset nyudv2 -c swin_large_window12_finetune_dpr0.15_100+200+100 \
 --dpr 0.15 --num-epoch 100 200 100 --is_pretrain_finetune --resume ./swin-large-384.pth.tar

SUN RGBD Training

On the SUN RGBD dataset, we train GeminiFusion on 4 GPUs.

# mit-b3
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4  --use_env main.py --backbone mit_b3 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_mit_b3

# mit-b5
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4  --use_env main.py --backbone mit_b5 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_mit_b5 --weight_decay 0.05

# swin_tiny
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4  --use_env main.py --backbone swin_tiny --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_swin_tiny

# swin_large_window12
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4  --use_env main.py --backbone swin_large_window12 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_swin_large_window12

Test

To evaluate a checkpoint, append --eval --resume <checkpoint path> to the training command.

For example, on the NYUDv2 dataset, the training command for GeminiFusion with the MiT-B3 backbone is:

CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3

To evaluate a trained or downloaded checkpoint, the evaluation command is:

CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3  --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 --eval --resume mit-b3.pth.tar
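Since the evaluation command is just the training command with two extra flags, the pattern can be captured in a small helper when scripting many runs. This is a sketch under stated assumptions: `launch_command` is a hypothetical function, not part of this repository, and it uses only flags that appear in the commands above. `CUDA_VISIBLE_DEVICES` is an environment variable rather than an argument, so it would have to be set separately, for example through the `env` parameter of `subprocess.run`.

```python
import shlex

def launch_command(backbone, dataset, config, gpus=3, extra=(), eval_ckpt=None):
    """Assemble the launch line used in this README as an argv list."""
    cmd = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={gpus}", "--use_env", "main.py",
        "--backbone", backbone, "--dataset", dataset, "-c", config,
        *extra,  # optional flags such as ("--dpr", "0.35")
    ]
    if eval_ckpt is not None:
        # Evaluation is the training command plus these two flags.
        cmd += ["--eval", "--resume", eval_ckpt]
    return cmd

train = launch_command("mit_b3", "nyudv2", "nyudv2_mit_b3")
evaluate = launch_command(
    "mit_b3", "nyudv2", "nyudv2_mit_b3", eval_ckpt="mit-b3.pth.tar"
)
print(shlex.join(evaluate))
```

Returning a list rather than a string keeps the arguments safe to pass to `subprocess.run` without shell quoting issues, and `shlex.join` renders it back into a copy-pasteable line.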

Citation

If you find this work useful for your research, please cite our paper:

@misc{jia2024geminifusion,
      title={GeminiFusion: Efficient Pixel-wise Multimodal Fusion for Vision Transformer}, 
      author={Ding Jia and Jianyuan Guo and Kai Han and Han Wu and Chao Zhang and Chang Xu and Xinghao Chen},
      year={2024},
      eprint={2406.01210},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

Part of our code is based on the open-source project TokenFusion.