This repository provides example code for the BMVC 2020 paper Automated Search for Resource-Efficient Branched Multi-Task Networks.
BMTAS is a principled method to automatically find branched multi-task networks for a given set of tasks. The goal is to obtain models which are both high-performing and efficient during inference. This is achieved by optimizing a supergraph through differentiable neural architecture search. The desired performance vs. efficiency trade-off is controlled via a proxyless, resource-aware loss.
We provide example code for searching branched networks based on a MobileNetV2 backbone and DeepLabv3+ head on the PASCAL-Context dataset. The dataset provides labels for five dense prediction tasks: semantic segmentation, human parts segmentation, saliency estimation, surface normal estimation and edge detection.
The code was run in a conda
environment, using Python 3.7 with the following packages:
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
conda install scikit-image==0.16.2 tensorboard==2.2.1
conda install opencv==4.4.0 -c conda-forge
To start an architecture search, use for example:
python main_search.py --data_root . --tasks semseg,human_parts,sal,normals,edge --resource_loss_weight 0.05
On the first run, this will also download the dataset.
To train the branched network found via architecture search, supply the path to the corresponding branch_config.json
:
python main_branched.py --configuration </path/to/branch_config> --data_root . --tasks semseg,human_parts,sal,normals,edge
Due to GPU memory constraints, multi-GPU training may be required. By default, branched network training is distributed across all available GPUs on the node.
To evaluate a trained branched network, supply the path to the corresponding checkpoint.pth
:
python main_test.py --model_path </path/to/checkpoint> --data_root .
Note that the evaluation of edge detection is disabled, since the MATLAB-based SEISM repository was used for obtaining the optimal dataset F-measure scores. Instead, the edge predictions are simply saved on the disk in this code.
There are some minor differences between this version and the original code used for the paper:
- We use the Adam optimizer also for weights during the search - this way, we can avoid using a learning rate scheduler in the bilevel optimization.
- The learning rate is kept equal for all blocks when training the branched network, not adapted depending on the degree of operator sharing.
If you found this code useful in your research, please consider citing the paper:
@InProceedings{bruggemann2020automated,
Title = {Automated Search for Resource-Efficient Branched Multi-Task Networks},
Author = {Bruggemann, David and Kanakis, Menelaos and Georgoulis, Stamatios and Van Gool, Luc},
Booktitle = {BMVC},
Year = {2020}
}
The dataset and evaluation metrics are taken from the ASTMT repository.
For questions about the code or paper, feel free to contact me (send email).