This repository contains the dataset and code accompanying the CVPR 2022 paper "Learning to Learn and Remember Super Long Multi-Domain Task Sequence" (Oral).
Figure panels: Domain-Aware SDML, Domain-Agnostic SDML
- Python 3.7
- PyTorch 1.8.0
- torchmeta 1.7.0
- numpy 1.20.3
- tqdm
- Install the packages listed above.
- Download the ten datasets (['Quickdraw', 'Aircraft', 'CUB', 'MiniImagenet', 'Omniglot', 'Plantae', 'Electronic', 'CIFARFS', 'Fungi', 'Necessities']) from Google Drive here and put the dataset folder in the root directory of this project.
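Before training, it can help to confirm that all ten datasets are in place. The following is a minimal sketch, assuming each dataset is stored as a folder named exactly as in the list above; the actual folder layout of the download may differ, and `missing_datasets` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path

# Dataset names as listed in this README. Assumption: each one is a folder
# of the same name directly under the data root.
DATASETS = ['Quickdraw', 'Aircraft', 'CUB', 'MiniImagenet', 'Omniglot',
            'Plantae', 'Electronic', 'CIFARFS', 'Fungi', 'Necessities']


def missing_datasets(data_root):
    """Return the dataset names that have no matching folder under data_root."""
    root = Path(data_root)
    return [name for name in DATASETS if not (root / name).is_dir()]
```

Calling `missing_datasets('data/path')` with the same path passed to `--data_path` returns an empty list when every expected folder is found.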
Train the meta-learning models on sequentially arriving datasets:
python train_sequence.py --data_path 'data/path'
Train the meta-learning models (Prototypical Network) with Meta Experience Replay (MER) on sequentially arriving datasets:
python train_MER.py --data_path 'data/path'
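MER, as originally proposed, fills its replay memory by reservoir sampling, so that every example seen so far has the same chance of being stored. The sketch below illustrates only that buffer in plain Python. `ReservoirBuffer` and its methods are hypothetical names chosen for illustration and are not taken from `train_MER.py`, whose implementation may differ.

```python
import random


class ReservoirBuffer:
    """Fixed-size memory in which every item seen so far is equally likely to be kept."""

    def __init__(self, capacity, seed=None):
        self.capacity = capacity
        self.items = []
        self.num_seen = 0
        self.rng = random.Random(seed)

    def add(self, item):
        """Offer one new item to the buffer."""
        self.num_seen += 1
        if len(self.items) < self.capacity:
            self.items.append(item)
        else:
            # The new item replaces a stored one with probability
            # capacity / num_seen (standard reservoir sampling).
            slot = self.rng.randrange(self.num_seen)
            if slot < self.capacity:
                self.items[slot] = item

    def sample(self, batch_size):
        """Draw a replay batch without replacement (smaller if the buffer holds fewer items)."""
        return self.rng.sample(self.items, min(batch_size, len(self.items)))
```

In a replay-based loop, each incoming task or batch would be passed to `add`, and `sample` would supply the stored items that are mixed with the current batch.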
Train the meta-learning models (Prototypical Network) with Averaged GEM (AGEM) on sequentially arriving datasets:
python train_AGEM.py --data_path 'data/path'
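The central step of A-GEM is a gradient projection: if the gradient on the current batch has a negative inner product with a reference gradient computed on memory samples, the conflicting component is removed. The sketch below shows that rule on plain Python lists, under the assumption that both gradients have already been flattened to vectors of equal length. `agem_project` is a hypothetical name and is not taken from `train_AGEM.py`.

```python
def agem_project(grad, ref_grad):
    """Project grad so that it no longer conflicts with ref_grad (the A-GEM rule)."""
    dot = sum(g * r for g, r in zip(grad, ref_grad))
    if dot >= 0:
        # No interference with the memory gradient: keep the update unchanged.
        return list(grad)
    # dot < 0 implies ref_grad is non-zero, so the division is safe.
    ref_sq = sum(r * r for r in ref_grad)
    scale = dot / ref_sq
    return [g - scale * r for g, r in zip(grad, ref_grad)]
```

After projection, the inner product between the returned gradient and `ref_grad` is zero (up to floating-point error) whenever the original inner product was negative.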
Train the meta-learning models (Prototypical Network) with HAT on sequentially arriving datasets:
python train_HAT.py --data_path 'data/path'
Train the meta-learning models (Prototypical Network) with UCB on sequentially arriving datasets:
python train_UCB.py --data_path 'data/path'
Train the meta-learning models (Prototypical Network) with our method on sequentially arriving datasets:
python train_domain_aware.py --data_path 'data/path'
Train the meta-learning models (Prototypical Network) with online domain shift detection on sequentially arriving datasets:
python train_domain_shift_detection.py --data_path 'data/path'
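Online domain shift detection can be framed as changepoint detection on a stream of per-task statistics. The following is a generic sketch of Bayesian online changepoint detection on a one-dimensional stream, assuming Gaussian observations with known variance, a Gaussian prior on each segment mean, and a constant hazard rate. It is not the implementation in `train_domain_shift_detection.py`; the function names, the parameters, and the choice of input signal are assumptions made for illustration.

```python
import math


def normal_pdf(x, mean, var):
    """Density of a normal distribution with the given mean and variance."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def bocpd(stream, hazard=0.02, prior_mean=0.0, prior_var=1.0, obs_var=1.0):
    """Return, for each observation, the most probable current run length."""
    run_probs = [1.0]        # P(run length = r | data so far)
    means = [prior_mean]     # posterior mean of the segment mean, per run length
    variances = [prior_var]  # posterior variance of the segment mean, per run length
    map_run_lengths = []
    for x in stream:
        # Predictive probability of x under each run-length hypothesis.
        pred = [normal_pdf(x, m, v + obs_var) for m, v in zip(means, variances)]
        joint = [p * q for p, q in zip(run_probs, pred)]
        # Either the run grows by one step, or a changepoint resets it to 0.
        growth = [j * (1.0 - hazard) for j in joint]
        changepoint = sum(joint) * hazard
        run_probs = [changepoint] + growth
        total = sum(run_probs)
        run_probs = [p / total for p in run_probs]
        # Conjugate Gaussian update for every hypothesis; run length 0 restarts from the prior.
        new_vars = [1.0 / (1.0 / v + 1.0 / obs_var) for v in variances]
        new_means = [nv * (m / v + x / obs_var)
                     for nv, m, v in zip(new_vars, means, variances)]
        means = [prior_mean] + new_means
        variances = [prior_var] + new_vars
        map_run_lengths.append(max(range(len(run_probs)), key=run_probs.__getitem__))
    return map_run_lengths
```

On a stream whose mean jumps, the returned run length grows while the stream is stable and then falls back to the number of observations since the jump, which is the signal a detector would act on.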
@InProceedings{Wang_2022_CVPR,
author = {Wang, Zhenyi and Shen, Li and Duan, Tiehang and Zhan, Donglin and Fang, Le and Gao, Mingchen},
title = {Learning To Learn and Remember Super Long Multi-Domain Task Sequence},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022},
pages = {7982-7992}
}
Some of the Bayesian online changepoint detection code is from link