diff --git a/README.md b/README.md
index 3f8d42c..d48d74c 100644
--- a/README.md
+++ b/README.md
@@ -235,14 +235,18 @@ In addition to shared hyperparameters such as `lr`, `weight_decay`, `batch_size`
 In order to make a fair comparison across different TTA algorithms, we make reasonable modifications to these algorithms, which may induce inconsistency with their official implementation.
 -->
 ## Pretraining
-In `pretrain`, we provide an improved pretraining script based on what we used in our project, which can be used to pretrain the model on all of benchmark datasets used in our paper except ImageNet. Meanwhile, in this [link](https://drive.google.com/drive/folders/1KBbcNB6KR5Fqi6iueASLb85LxpkraX3K?usp=sharing), we release a set of checkpoints pretrained on the in-distribution TTAB datasets. These pre-trained models were used to benchmark baselines in our paper. Note that we adopt self-supervised learning with a rotation prediction task to train the baseline model in our paper for a fair comparison. In practice, please feel free to choose whatever pre-training methods you prefer, but please pay attention to the setup of TTA methods.
+In `pretrain`, we provide an improved pretraining script based on what we used in our project, which can be used to pretrain the model on all of benchmark datasets used in our paper except ImageNet. Meanwhile, in this [link](https://drive.google.com/drive/folders/1KBbcNB6KR5Fqi6iueASLb85LxpkraX3K?usp=sharing), we release a set of checkpoints pretrained on the in-distribution TTAB datasets. These pre-trained models were used to benchmark baselines in our paper. Note that we adopt self-supervised learning with a rotation prediction task to train the baseline model in our paper for a fair comparison. `GroupNorm` is enabled for the base model by specifying values for the `--group_norm` argument. In practice, please feel free to choose whatever pre-training methods you prefer, but please pay attention to the setup of TTA methods.
 ```py
+# BatchNorm
 python ssl_pretrain.py --data-name cifar10 --model-name resnet26
 python ssl_pretrain.py --data-name cifar100 --model-name resnet26
 python ssl_pretrain.py --data-name officehome_art --model-name resnet50 --entry-of-shared-layers layer3 --use-ls --lr 1e-2 --weight-decay 1e-4
 python ssl_pretrain.py --data-name pacs_art --model-name resnet50 --entry-of-shared-layers layer3 --use-ls --lr 1e-2 --weight-decay 1e-4
 python ssl_pretrain.py --data-name waterbirds --model-name resnet50 --entry-of-shared-layers layer3 --lr 1e-3 --weight-decay 1e-4
 python ssl_pretrain.py --data-name coloredmnist --model-name resnet18 --entry-of-shared-layers layer3 --lr 1e-3 --weight-decay 1e-4
+
+# GroupNorm
+python ssl_pretrain.py --data-name cifar10 --model-name resnet26 --group_norm 8
 ```
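
Note for readers of this change: the diff does not show how `ssl_pretrain.py` consumes `--group_norm`, so the following is only a minimal sketch of one plausible wiring, not the script's actual code. It assumes that the value (e.g. `8`) is the number of groups and that omitting the flag keeps `BatchNorm`. The helper names `build_parser`, `norm_layer_choice`, and `make_norm_layer` are hypothetical; `torch.nn.BatchNorm2d(num_features)` and `torch.nn.GroupNorm(num_groups, num_channels)` are the standard PyTorch layers.

```python
import argparse


def build_parser():
    # Hypothetical subset of the flags used in the commands above.
    parser = argparse.ArgumentParser(description="Sketch of pretraining flags")
    parser.add_argument("--data-name", type=str, default="cifar10")
    parser.add_argument("--model-name", type=str, default="resnet26")
    # Assumption: None keeps BatchNorm; an integer is the number of groups.
    parser.add_argument("--group_norm", type=int, default=None)
    return parser


def norm_layer_choice(num_channels, group_norm=None):
    """Return (layer_name, kwargs) for the normalization layer to build."""
    if group_norm is None:
        return "BatchNorm2d", {"num_features": num_channels}
    # GroupNorm requires the channel count to be divisible by the group count.
    if group_norm <= 0 or num_channels % group_norm != 0:
        raise ValueError(
            f"num_channels={num_channels} is not divisible by group_norm={group_norm}"
        )
    return "GroupNorm", {"num_groups": group_norm, "num_channels": num_channels}


def make_norm_layer(num_channels, group_norm=None):
    # PyTorch is imported lazily so the helpers above work without it installed.
    import torch.nn as nn

    name, kwargs = norm_layer_choice(num_channels, group_norm)
    return getattr(nn, name)(**kwargs)
```

Under these assumptions, `--group_norm 8` on a 64-channel block would select `GroupNorm` with 8 groups of 8 channels each, while leaving the flag unset would fall back to `BatchNorm2d(64)`. This matters for TTA setups because `GroupNorm` does not depend on batch statistics.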