This repository is the official implementation of Meta-DMoE: Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts.
The code was tested on python3.7 and CUDA10.1.
We recommend using conda environment to setup all required dependencies:
conda env create -f environment.yml
conda activate dmoe
If you have any problem with the above command, you can also install them by pip install -r requirements.txt
.
Either of these commands will automatically install all required dependecies except for the torch-scatter
and torch-geometric
packages , which require a quick manual install.
We provide the training script for the following 4 datasets from the WILDS benchmark: iwildcam
, camelyon
, rxrx1
, and FMoW
. To train the models in the paper, run the following commands:
python run.py --dataset <dataset>
The data will be automatically downloaded to the data folder.
Although we are not able to provide Multi-GPU support for meta-training at this point, you could still consider training the expert models in a distributed manner by opening multiple terminals and running:
python train_single_expert.py --dataset <dataset> --data_dir <path to data folder> --gpu <which GPU> --expert_idx <which expert in {0, ... , Num_Experts - 1}>
and add --load_trained_experts
flag when running run.py
.
To evaluate trained models, run:
python run.py --dataset <dataset> --data_dir <path to data folder> --test
To reproduce the results reported in Table 1 in our paper, you can download pretrained models here and extract to model/<dataset>
folder. Note that due to the size limit of cloud storage, we only uploaded checkpoints from one random seed per dataset, while the results reported in the table are aggregated across several random seeds.
If you find this codebase useful in your research, consider citing:
@inproceedings{
zhong2022metadmoe,
title={Meta-{DM}oE: Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts},
author={Tao Zhong and Zhixiang Chi and Li Gu and Yang Wang and Yuanhao Yu and Jin Tang},
booktitle={Thirty-Sixth Conference on Neural Information Processing Systems (NeurIPS)},
year={2022}
}