This repository provides source code for the ICCV 2021 paper Exploring Relational Context for Multi-Task Dense Prediction. The code is organized using PyTorch Lightning.
ATRC is an attention-driven module that refines task-specific dense predictions by capturing cross-task contexts. Through Neural Architecture Search (NAS), ATRC selects the context type used for multi-modal distillation of each source-target task pair, based on the relation between the two tasks. We investigate four context types: global, local, t-label and s-label (as well as the option to sever the cross-task connection entirely). In the figure above, each CP block handles one source-target task connection.
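For intuition only, here is a minimal sketch of a generic cross-task attention block in the spirit of the global context type: the target task's features query the source task's features over all spatial positions, and the aggregated context is added back into the target features. The class name GlobalCrossTaskAttention and all details are invented for illustration; this is not the repository's actual CP block implementation.

import torch
import torch.nn as nn

class GlobalCrossTaskAttention(nn.Module):
    # Illustrative sketch only: target-task features attend to source-task
    # features over all spatial positions (global cross-task context).
    def __init__(self, channels, key_channels):
        super().__init__()
        self.query = nn.Conv2d(channels, key_channels, kernel_size=1)
        self.key = nn.Conv2d(channels, key_channels, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, target_feat, source_feat):
        b, c, h, w = target_feat.shape
        q = self.query(target_feat).flatten(2).transpose(1, 2)    # B x HW x Ck
        k = self.key(source_feat).flatten(2)                      # B x Ck x HW
        v = self.value(source_feat).flatten(2).transpose(1, 2)    # B x HW x C
        attn = torch.softmax(q @ k / k.shape[1] ** 0.5, dim=-1)   # B x HW x HW
        context = (attn @ v).transpose(1, 2).reshape(b, c, h, w)  # aggregated source context
        return target_feat + context                              # distill into target features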
We provide code for searching ATRC configurations and training various multi-modal distillation networks on the NYUD-v2 and PASCAL-Context benchmarks, based on HRNet backbones.
The code is run in a conda environment with Python 3.8.11:
conda install pytorch==1.7.0 torchvision==0.8.1 cudatoolkit=10.1 -c pytorch
conda install pytorch-lightning==1.1.8 -c conda-forge
conda install opencv==4.4.0 -c conda-forge
conda install scikit-image==0.17.2
pip install jsonargparse[signatures]==3.17.0
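To quickly verify that the pinned versions were picked up by the environment, a short sanity check along these lines can be used (this snippet is only an illustration and not part of the repository):

# Sketch: print installed versions; they should match the pinned versions above.
import torch, torchvision, pytorch_lightning, cv2, skimage, jsonargparse

for name, module in [("pytorch", torch), ("torchvision", torchvision),
                     ("pytorch-lightning", pytorch_lightning), ("opencv", cv2),
                     ("scikit-image", skimage), ("jsonargparse", jsonargparse)]:
    print(f"{name}: {module.__version__}")
print("CUDA available:", torch.cuda.is_available())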
NOTE: PyTorch Lightning is still under heavy development, so make sure to use version 1.1.8 with this code to avoid compatibility issues.
To start an ATRC search on NYUD-v2 with a HRNetV2-W18-small backbone, use for example:
python ./src/main_search.py --cfg ./config/nyud/hrnet18/atrc_search.yaml --datamodule.data_dir . --trainer.gpus 2 --trainer.accelerator ddp
The path to the data directory can be customized with --datamodule.data_dir. The data is downloaded automatically on the first run. At every validation epoch, the current ATRC configuration is saved as an atrc_genotype.json file in the log directory.
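Since the genotype is stored as plain JSON, it can be inspected with a few lines of Python, for example (the file path below is an assumption; substitute the actual log directory):

# Sketch: inspect a searched ATRC configuration.
import json

with open("atrc_genotype.json") as f:  # e.g. located in the run's log directory
    genotype = json.load(f)
print(json.dumps(genotype, indent=2))  # the searched ATRC configuration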
To train ATRC distillation networks, supply the path to the corresponding atrc_genotype.json file (in the example below it resides in $GENOTYPE_DIR):
python ./src/main.py --cfg ./config/nyud/hrnet18/atrc.yaml --model.atrc_genotype_path $GENOTYPE_DIR/atrc_genotype.json --datamodule.data_dir . --trainer.gpus 1
Some genotype files can be found under genotypes/.
Baselines can be run by selecting the corresponding config file, e.g., for the multi-task learning baseline:
python ./src/main.py --cfg ./config/nyud/hrnet18/baselinemt.yaml --datamodule.data_dir . --trainer.gpus 1
The evaluation of boundary detection is disabled, since the MATLAB-based SEISM repository was used to obtain the optimal dataset F-measure scores. Instead, this code simply saves the boundary predictions to disk.
NOTE: Following previous work, the SEISM boundary detection evaluation uses maxDist=0.0075 for PASCAL-Context and maxDist=0.011 for NYUD-v2.
If you find this code useful in your research, please consider citing the paper:
@InProceedings{bruggemann2021exploring,
Title = {Exploring Relational Context for Multi-Task Dense Prediction},
Author = {Bruggemann, David and Kanakis, Menelaos and Obukhov, Anton and Georgoulis, Stamatios and Van Gool, Luc},
Booktitle = {ICCV},
Year = {2021}
}
This repository is released under the MIT license. Note, however, that third-party code included in this repository may be subject to its own license; such code is marked accordingly.