# Test-Time Adaptation Benchmark (TTAB)

## Overview
TTAB is a benchmark for standardizing and comprehensively evaluating test-time adaptation (TTA) algorithms on a diverse array of distribution shifts.

The TTAB package contains:
1. Data loaders that automatically handle data processing and splitting to cover multiple significant evaluation settings considered in prior work.
2. Unified dataset evaluators that standardize model evaluation for each dataset and setting.
3. Multiple representative TTA algorithms.

In addition, the example scripts contain default models, optimizers, and evaluation code. New algorithms can be easily added and run on all of the TTAB datasets.

## Installation
To run a baseline test, please prepare the relevant pre-trained checkpoints for the base model and place them in `pretrain/ckpt/`.

### Requirements
The TTAB package depends on the following requirements:

- numpy>=1.21.5
- pandas>=1.1.5
- pillow>=9.0.1
- pytz>=2021.3
- torch>=1.7.1
- torchvision>=0.8.2
- timm>=0.6.11
- scikit-learn>=1.0.3
- scipy>=1.7.3
- tqdm>=4.56.2

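As a quick sanity check (not part of the TTAB package), the installed versions can be compared against the pins above; the strings below are PyPI distribution names:

```py
# List installed versions of the required distributions.
import importlib.metadata as md

for pkg in ["numpy", "pandas", "Pillow", "pytz", "torch", "torchvision",
            "timm", "scikit-learn", "scipy", "tqdm"]:
    try:
        print(f"{pkg}=={md.version(pkg)}")
    except md.PackageNotFoundError:
        print(f"{pkg} is not installed")
```
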
## Datasets
Distribution shift occurs when the test distribution differs from the training distribution, and it can considerably degrade the performance of machine learning models deployed in the real world. The form of distribution shift varies greatly across applications in practice. In TTAB, we collect 10 datasets and systematically sort them into 5 types of distribution shift:
- Covariate Shift
- Natural Shift
- Domain Generalization
- Label Shift
- Spurious Correlation Shift

![TTAB -- Dataset Description](./figs/overview%20of%20datasets.jpg)
<!-- | Dataset | Types of distribution shift | Access to the dataset |
| ----------- | ---------------------------- | ------------------------------------------------------------------------ |
| CIFAR10-C | Covariate shift | [link](https://zenodo.org/record/2535967#.Y_F1DXbMI2w) |
| CIFAR10.1 | Natural shift | [link](https://github.com/modestyachts/CIFAR-10.1/tree/master/datasets) |
| OfficeHome | Domain Generalization | [link](https://www.hemanthdv.org/officeHomeDataset.html) |
| PACS | Domain Generalization | [link](https://dali-dl.github.io/project_iccv2017.html) |
| Waterbirds | Spurious correlation | [link](https://github.com/kohpangwei/group_DRO) |
| ColoredMNIST | Spurious correlation | torchvision or [link](http://yann.lecun.com/exdb/mnist/) | -->

## Using the TTAB package

The TTAB package provides a simple, standardized interface for all TTA algorithms and datasets in the benchmark. The short Python snippet below covers all of the steps needed to get started with a user-customizable configuration, including the choice of TTA algorithm, dataset, base model, model selection method, experimental setup, evaluation scenario (discussed in more detail in [Scenario](#scenario)), and protocol.

```py
config, scenario = configs_utils.config_hparams(config=init_config)

# Dataset.
test_data_cls = define_dataset.ConstructTestDataset(config=config)
test_loader = test_data_cls.construct_test_loader(scenario=scenario)

# Base model.
model = define_model(config=config)
load_pretrained_model(config=config, model=model)

# Algorithms.
model_adaptation_cls = get_model_adaptation_method(
    adaptation_name=scenario.model_adaptation_method
)(meta_conf=config, model=model)
model_selection_cls = get_model_selection_method(selection_name=scenario.model_selection_method)(
    meta_conf=config, model=model
)

# Evaluate.
benchmark = Benchmark(
    scenario=scenario,
    model_adaptation_cls=model_adaptation_cls,
    model_selection_cls=model_selection_cls,
    test_loader=test_loader,
    meta_conf=config,
)
benchmark.eval()
```

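For reference, the snippet assumes the relevant helpers are imported from the `ttab` package and that `init_config` holds the user-provided configuration (e.g., the parsed `args` from `parameters.py`). A hedged sketch of the imports, where the exact module paths are assumptions inferred from the names above rather than verified paths:

```py
# NOTE: module paths are assumptions; check the repository layout before use.
from ttab.benchmark import Benchmark
from ttab.loads import define_dataset
from ttab.loads.define_model import define_model, load_pretrained_model
from ttab.model_adaptation import get_model_adaptation_method
from ttab.model_selection import get_model_selection_method
import ttab.configs.utils as configs_utils
```
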
### Data loading
For evaluation, the TTAB package provides two types of dataset objects. The standard dataset object stores data, labels, and indices, as well as several APIs to support high-level manipulation, such as mixing the source and target domains. The standard dataset object serves common evaluation metrics like top-1 accuracy and cross-entropy.

To support other metrics, such as worst-group accuracy, for more robust evaluation, we provide a group-wise dataset object that records additional group information; a sketch of this metric follows.

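For intuition, worst-group accuracy is simply the minimum of the per-group top-1 accuracies. Below is a minimal, self-contained sketch of the metric; it is independent of TTAB's group-wise dataset object, whose exact field names are not shown here.

```py
import torch

def worst_group_accuracy(preds, labels, groups):
    """Top-1 accuracy of the worst-performing group.

    `preds`, `labels`, and `groups` are 1-D tensors of equal length;
    `groups` holds an integer group id per sample (e.g., the
    background/foreground combinations in Waterbirds).
    """
    accs = []
    for g in groups.unique():
        mask = groups == g
        accs.append((preds[mask] == labels[mask]).float().mean())
    return torch.stack(accs).min().item()
```
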
To provide a more seamless user experience, we have designed a unified data loader that supports all dataset objects. To load data in TTAB, simply run the following snippet with `config` and `scenario` as inputs.

```py
test_data_cls = define_dataset.ConstructTestDataset(config=config)
test_loader = test_data_cls.construct_test_loader(scenario=scenario)
```

### Scenario
A scenario collects all of the parameters needed to define a distribution shift problem in practice, such as `test_domain` and `test_case`. The `test_domain` specifies the implicit $\mathcal{P}(a^{1:K})$ and the selected sampling strategy, while `test_case` determines how the dataset corresponding to `test_domain` is organized into a data stream that is fed to the TTA method. In addition, the scenario defines the model architecture, the TTA method, and the model selection method used for the defined distribution shift problem.

Here, we present an example of a `scenario`. Please feel free to suggest a new `scenario` for your research.

```py
"S1": Scenario(
    task="classification",
    model_name="resnet26",
    model_adaptation_method="tent",
    model_selection_method="last_iterate",
    base_data_name="cifar10",
    test_domains=[
        TestDomain(
            base_data_name="cifar10",
            data_name="cifar10_c_deterministic-gaussian_noise-5",
            shift_type="synthetic",
            shift_property=SyntheticShiftProperty(
                shift_degree=5,
                shift_name="gaussian_noise",
                version="deterministic",
                has_shift=True,
            ),
            domain_sampling_name="uniform",
            domain_sampling_value=None,
            domain_sampling_ratio=1.0,
        )
    ],
    test_case=TestCase(
        inter_domain=HomogeneousNoMixture(has_mixture=False),
        batch_size=64,
        data_wise="batch_wise",
        offline_pre_adapt=False,
        episodic=False,
        intra_domain_shuffle=True,
    ),
),
```

## Using the example scripts
We provide an example script that can be used to adapt models to the distribution shifts in the TTAB datasets:

```bash
python run_exp.py
```

Currently, before using the example script, you need to manually set up the `args` object in `parameters.py`. The script is configured to use the default base model, dataset, evaluation protocol, and reasonable hyperparameters.

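Since `run_exp.py` is assumed to consume the flags defined in `parameters.py` (shown in full further below), the defaults can also be overridden programmatically; a small sketch:

```py
# get_args() parses sys.argv, so we fake a command line here;
# the flag names come from parameters.py.
import sys
from parameters import get_args

sys.argv = [
    "run_exp.py",
    "--model_adaptation_method", "tent",
    "--base_data_name", "cifar10",
    "--lr", "1e-3",
]
args = get_args()
print(args.model_adaptation_method, args.base_data_name, args.lr)
```
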
<!-- ## Algorithms
In the `ttab/model_adaptation` folder, we provide implementations of the TTA algorithms benchmarked in our paper. We use unified setups for the base model, datasets, hyperparameters, and evaluators, so new algorithms can be easily added and run on all of the TTAB datasets.
In addition to shared hyperparameters such as `lr`, `weight_decay`, `batch_size`, and `optimizer`, the scripts also take in command-line arguments for algorithm-specific hyperparameters.
| Algorithm | Venue | Adjust pretraining | Access to source domain | Reuse test data | Coupled w/ BatchNorm | Resetting model | Optimizer |
|:----------------------------------------:|:------------:|:------------------:|:-----------------------:|:---------------:|:--------------------:|:---------------:|:-----------:|
| [SHOT](https://arxiv.org/abs/2002.08546) | ICML 2020 | ✗ | ✗ | ✓ | ✗ | ✗ | SGD |
| [TTT](https://arxiv.org/abs/1909.13231) | ICML 2020 | ✓ | ✗ | ✗ | ✗ | ✗ | SGD |
| [BN_Adapt](https://arxiv.org/abs/2006.16971) | NeurIPS 2020 | ✗ | ✗ | ✗ | ✓ | ✗ | - |
| [TENT](https://arxiv.org/abs/2006.10726) | ICLR 2021 | ✗ | ✗ | ✗ | ✓ | ✗ | Adam & SGDm |
| [T3A](https://openreview.net/forum?id=e_yvNqkJKAW) | NeurIPS 2021 | ✗ | ✗ | ✗ | ✗ | ✗ | - |
| [Conjugate PL](http://arxiv.org/abs/2207.09640) | NeurIPS 2022 | ✗ | ✗ | ✗ | ✓ | ✗ | Adam |
| [MEMO](https://arxiv.org/abs/2110.09506) | NeurIPS 2022 | ✗ | ✗ | ✗ | ✗ | ✓ | SGD |
| [NOTE](https://arxiv.org/abs/2208.05117) | NeurIPS 2022 | ✓ | ✗ | ✗ | ✓ | ✗ | Adam |
| [SAR](https://openreview.net/pdf?id=g2YraF75Tj) | ICLR 2023 | ✗ | ✗ | ✗ | ✓ | ✗ | SAM |
To make a fair comparison across different TTA algorithms, we make reasonable modifications to these algorithms, which may cause inconsistencies with their official implementations. -->

## Pretraining
At this [link](https://drive.google.com/drive/folders/1ALNIYnnTJwqP80n9pEjSWtb_UdbcrsVi?usp=sharing), we provide a set of scripts that can be used to pre-train models on the in-distribution TTAB datasets. These pre-trained models were used to benchmark baselines in our paper. Note that, for a fair comparison, the baseline models in our paper are trained with an auxiliary self-supervised rotation prediction task, as sketched below. In practice, feel free to choose whatever pre-training method you prefer, but pay attention to the setup requirements of the TTA methods.
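
For intuition, here is a minimal sketch of such a 4-way rotation prediction auxiliary task; `backbone`, `rotation_head`, and `cls_loss` are hypothetical placeholders, and the linked scripts remain authoritative.

```py
import torch

def make_rotation_batch(x):
    """Rotate each image in an (N, C, H, W) batch by 0/90/180/270 degrees
    and label it with the rotation index, yielding a 4-way auxiliary task."""
    rotated = torch.cat([torch.rot90(x, k, dims=(2, 3)) for k in range(4)])
    labels = torch.arange(4).repeat_interleave(x.size(0))
    return rotated, labels

# Hypothetical usage inside a pretraining step:
#   x_rot, y_rot = make_rotation_batch(images)
#   aux_loss = torch.nn.functional.cross_entropy(rotation_head(backbone(x_rot)), y_rot)
#   loss = cls_loss + aux_loss
```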
<!-- ## Citing TTAB -->

## Acknowledgements

`parameters.py`:

```py
# -*- coding: utf-8 -*-
import argparse


def get_args():
    parser = argparse.ArgumentParser()
    # define test evaluation info.
    parser.add_argument("--root_path", default="./data/logs", type=str)
    parser.add_argument("--data_path", default="./datasets", type=str)
    parser.add_argument(
        "--ckpt_path",
        default="./pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth",
        type=str,
    )
    parser.add_argument("--seed", default=2022, type=int)
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--num_cpus", default=2, type=int)

    # define the task & model & adaptation & selection method.
    parser.add_argument("--model_name", default="resnet26", type=str)
    parser.add_argument("--group_norm_num_groups", default=None, type=int)
    parser.add_argument(
        "--model_adaptation_method",
        default="tent",
        choices=[
            "no_adaptation",
            "tent",
            "bn_adapt",
            "memo",
            "shot",
            "t3a",
            "ttt",
            "note",
            "sar",
            "conjugate_pl",
            "cotta",
            "eata",
        ],
        type=str,
    )
    parser.add_argument(
        "--model_selection_method",
        default="last_iterate",
        choices=["last_iterate", "oracle_model_selection"],
        type=str,
    )
    parser.add_argument("--task", default="classification", type=str)

    # define the test scenario.
    parser.add_argument("--test_scenario", default=None, type=str)
    parser.add_argument(
        "--base_data_name",
        default="cifar10",
        choices=[
            "cifar10",
            "cifar100",
            "imagenet",
            "officehome",
            "pacs",
            "coloredmnist",
            "waterbirds",
        ],
        type=str,
    )
    parser.add_argument("--src_data_name", default="cifar10", type=str)
    parser.add_argument(
        "--data_names", default="cifar10_c_deterministic-gaussian_noise-5", type=str
    )
    parser.add_argument(
        "--data_wise",
        default="batch_wise",
        choices=["batch_wise", "sample_wise"],
        type=str,
    )
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--lr", default=1e-3, type=float)
    parser.add_argument("--n_train_steps", default=1, type=int)
    parser.add_argument("--offline_pre_adapt", default=False, type=str2bool)
    parser.add_argument("--episodic", default=False, type=str2bool)
    parser.add_argument("--intra_domain_shuffle", default=True, type=str2bool)
    parser.add_argument(
        "--inter_domain",
        default="HomogeneousNoMixture",
        choices=[
            "HomogeneousNoMixture",
            "HeterogeneousNoMixture",
            "InOutMixture",
            "CrossMixture",
        ],
        type=str,
    )
    # Test domain
    parser.add_argument("--domain_sampling_name", default="uniform", type=str)
    parser.add_argument("--domain_sampling_ratio", default=1.0, type=float)
    # HeterogeneousNoMixture
    parser.add_argument("--non_iid_pattern", default="class_wise_over_domain", type=str)
    parser.add_argument("--non_iid_ness", default=0.1, type=float)
    # for evaluation.
    # label shift
    parser.add_argument(
        "--label_shift_param",
        help="parameter to control the severity of label shift",
        default=None,
        type=float,
    )
    parser.add_argument(
        "--data_size",
        help="parameter to control the size of dataset",
        default=None,
        type=int,
    )
    # optimal model selection
    parser.add_argument(
        "--step_ratios",
        nargs="+",
        default=[0.1, 0.3, 0.5, 0.75],
        help="ratios used to control adaptation step length",
        type=float,
    )
    parser.add_argument("--step_ratio", default=None, type=float)
    # time-varying
    parser.add_argument("--stochastic_restore_model", default=False, type=str2bool)
    parser.add_argument("--restore_prob", default=0.01, type=float)
    parser.add_argument("--fishers", default=False, type=str2bool)
    parser.add_argument(
        "--fisher_size",
        default=5000,
        type=int,
        help="number of samples to compute fisher information matrix.",
    )
    parser.add_argument(
        "--fisher_alpha",
        type=float,
        default=1.5,
        help="the trade-off between entropy and regularization loss",
    )
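    # The Fisher-related flags above configure an anti-forgetting regularizer
    # in the style of EATA, weighting a Fisher-information penalty against the
    # entropy objective (see the "eata" adaptation choice); this reading is
    # inferred from the flag names and help strings.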
    # method-wise hparams
    parser.add_argument(
        "--aug_size",
        default=32,
        help="number of per-image augmentation operations in memo and ttt",
        type=int,
    )
    parser.add_argument(
        "--entry_of_shared_layers",
        default=None,
        help="the split position of auxiliary head. Only used in TTT.",
    )
    # metrics
    parser.add_argument(
        "--record_preadapted_perf",
        default=False,
        help="record performance on the local batch prior to implementing test-time adaptation.",
        type=str2bool,
    )
    # misc
    parser.add_argument(
        "--grad_checkpoint",
        default=False,
        help="Trade computation for gpu space.",
        type=str2bool,
    )
    parser.add_argument("--debug", default=False, help="Display logs.", type=str2bool)

    # parse conf.
    conf = parser.parse_args()
    return conf


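# argparse's plain `bool` type would treat any non-empty string (including
# "False") as True, so the boolean flags above use this explicit converter.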
def str2bool(v):
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        # ArgumentTypeError lets argparse report a clean usage error.
        raise argparse.ArgumentTypeError("Boolean value expected.")


if __name__ == "__main__":
    args = get_args()
```