
AdaMTL: Adaptive Input-dependent Inference for Efficient Multi-Task Learning


Introduction

This is the official implementation of the paper: AdaMTL: Adaptive Input-dependent Inference for Efficient Multi-Task Learning.

This repository provides a Python implementation of the adaptive multi-task learning (MTL) approach proposed in the paper. The method does not run the same fixed computation for every input. Instead, it adapts inference to each input, which reduces computational cost while improving performance across the tasks. The repository is based on Swin-Transformer and uses some modules from Multi-Task-Learning-PyTorch.
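The toy sketch below illustrates "adapting inference based on input". It uses one common way to realize this idea, chosen here as an assumption and not taken from the AdaMTL code. A controller inspects each input and decides which residual blocks of the backbone to execute. A skipped residual block reduces to the identity and costs no compute. The names `residual`, `adaptive_forward`, and `toy_controller` and the 0.5 threshold are all hypothetical. In AdaMTL the controller is trained (see the controller pretraining stage under How to Run), whereas here it is a hand-written rule.

```python
"""Toy sketch of input-dependent block skipping (hypothetical, not AdaMTL's code).

Assumptions: the backbone is a list of residual blocks, a controller maps an
input to one keep-probability per block, and blocks below a threshold are skipped.
"""
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Block = Callable[[Vector], Vector]


def residual(branch: Block) -> Block:
    """Wrap a branch f as x -> x + f(x), so skipping the block means identity."""
    def run(x: Vector) -> Vector:
        return [a + b for a, b in zip(x, branch(x))]
    return run


def adaptive_forward(
    x: Vector,
    blocks: Sequence[Block],
    controller: Callable[[Vector], Sequence[float]],
    threshold: float = 0.5,
) -> Tuple[Vector, List[bool]]:
    """Execute only the blocks the controller keeps for this particular input."""
    keep_probs = controller(x)  # all decisions are computed once from the input
    if len(keep_probs) != len(blocks):
        raise ValueError("controller must return one probability per block")
    executed: List[bool] = []
    for block, p in zip(blocks, keep_probs):
        keep = p >= threshold
        if keep:
            x = block(x)  # compute is spent only on kept blocks
        executed.append(keep)
    return x, executed


def toy_controller(x: Vector) -> List[float]:
    """Hypothetical stand-in for a learned controller:
    inputs with larger mean magnitude keep more of the later blocks."""
    s = sum(abs(v) for v in x) / len(x)
    return [1.0, 0.9, s, s / 2]


# Four identical toy residual blocks, each computing x + 0.1 * x.
BLOCKS = [residual(lambda v: [0.1 * a for a in v]) for _ in range(4)]

if __name__ == "__main__":
    for sample in ([1.0, 1.0], [0.2, 0.2]):
        out, executed = adaptive_forward(sample, BLOCKS, toy_controller)
        print(sample, executed, f"{sum(executed)}/{len(BLOCKS)} blocks run")
```

With this rule, the input `[1.0, 1.0]` runs all four blocks, while `[0.2, 0.2]` runs only the first two. The compute saved therefore depends on the input rather than being fixed for the model.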

How to Run

To run the AdaMTL code, follow these steps:

  1. Clone the repository

    git clone https://github.com/scale-lab/AdaMTL.git
    cd AdaMTL
  2. Install the prerequisites

    • Install PyTorch>=1.12.0 and torchvision>=0.13.0 with CUDA>=11.6
    • Install dependencies: pip install -r requirements.txt
  3. Run the code

    Stage 1: Training the backbone

      python main.py --cfg configs/swin/<swin variant>.yaml --pascal <path to pascal database> --tasks semseg,normals,sal,human_parts --batch-size <batch size> --ckpt-freq=20 --epoch=1000 --resume-backbone <path to swin weights>

    Stage 2: Controller pretraining

      python main.py --cfg configs/ada_swin/<swin variant>_<tag/taw>_pretrain.yaml --pascal <path to pascal database> --tasks semseg,normals,sal,human_parts --batch-size <batch size> --ckpt-freq=20 --epoch=100 --resume <path to the weights generated from Stage 1>

    Stage 3: MTL model training

      python main.py --cfg configs/ada_swin/<swin variant>_<tag/taw>.yaml --pascal <path to pascal database> --tasks semseg,normals,sal,human_parts --batch-size <batch size> --ckpt-freq=20 --epoch=300 --resume <path to the weights generated from Stage 2>

    Swin variants and their weights can be found at the official Swin Transformer repository.

    Outputs are saved to the output/ folder unless overridden with the --output argument.
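The three stages form a chain in which each stage resumes from the weights produced by the previous one. The small Python driver below builds the commands listed above. By default it only prints them and runs them only when `dry_run=False`. It is a sketch and is not part of the repository. The helper names (`stage_command`, `run`), the example variant name, and all paths are hypothetical placeholders. The exact checkpoint file each stage writes under output/ is not specified here, so check that folder before passing a file to the next stage.

```python
"""Hypothetical driver for the three training stages described in How to Run."""
import shlex
import subprocess
from typing import List

TASKS = "semseg,normals,sal,human_parts"


def stage_command(stage: int, variant: str, policy: str, pascal: str,
                  batch_size: int, weights: str) -> List[str]:
    """Return the argv for one stage; policy selects the tag/taw config suffix."""
    common = ["--pascal", pascal, "--tasks", TASKS,
              "--batch-size", str(batch_size), "--ckpt-freq=20"]
    if stage == 1:
        # Backbone training starts from pretrained Swin weights.
        cfg = f"configs/swin/{variant}.yaml"
        extra = ["--epoch=1000", "--resume-backbone", weights]
    elif stage in (2, 3):
        if policy not in ("tag", "taw"):
            raise ValueError("policy must be 'tag' or 'taw'")
        suffix = "_pretrain" if stage == 2 else ""
        cfg = f"configs/ada_swin/{variant}_{policy}{suffix}.yaml"
        epochs = "--epoch=100" if stage == 2 else "--epoch=300"
        # Stages 2 and 3 resume from the previous stage's weights.
        extra = [epochs, "--resume", weights]
    else:
        raise ValueError("stage must be 1, 2 or 3")
    return ["python", "main.py", "--cfg", cfg, *common, *extra]


def run(argv: List[str], dry_run: bool = True) -> None:
    """Print the shell-quoted command; execute it only when dry_run is False."""
    print(shlex.join(argv))
    if not dry_run:
        subprocess.run(argv, check=True)


if __name__ == "__main__":
    # Hypothetical placeholder values; replace them with your own.
    variant, pascal = "swin_tiny_patch4_window7_224", "/data/PASCAL_MT"
    run(stage_command(1, variant, "tag", pascal, 8, "/weights/swin_tiny.pth"))
    run(stage_command(2, variant, "tag", pascal, 8, "/ckpts/stage1.pth"))
    run(stage_command(3, variant, "tag", pascal, 8, "/ckpts/stage2.pth"))
```

Keeping `dry_run=True` as the default lets you review every command, including the resume path chosen for each stage, before a long training run starts.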

Authorship

Since the release commit is squashed, the GitHub contributors tab doesn't reflect the authors' contributions. The following authors contributed equally to this codebase:

Citation

If you find AdaMTL helpful in your research, please cite our paper:

@inproceedings{neseem2023adamtl,
  title={AdaMTL: Adaptive Input-dependent Inference for Efficient Multi-Task Learning},
  author={Neseem, Marina and Agiza, Ahmed and Reda, Sherief},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={4729--4738},
  year={2023}
}

License

MIT License. See the LICENSE file.
