
Active Learning Helps Pretrained Models Learn the Intended Task, by Alex Tamkin, Dat Nguyen, Salil Deshpande, Jesse Mu, and Noah Goodman





This code is the official PyTorch implementation of the paper Active Learning Helps Pretrained Models Learn the Intended Task.

This project mainly uses the Google BiT models and reuses much of the code and settings from the official BiT repository. So far, only the vision tasks are available; the NLP portion of the code will be released later.


See requirements.txt for the list of required packages. They can be installed with

conda install --file requirements.txt

or

pip install -r requirements.txt

Note: The packages torch-scatter and torch-geometric, which are required for wilds, might need to be installed manually; see their installation instructions for details.


All datasets are loaded through the utils/datasets/ module. The function load in this module contains the list of available datasets. To add a new dataset, edit this function and add an entry to the known_datasets dictionary in utils/datasets/.
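The registry pattern described above can be sketched as follows. This is only an illustration: the real signature of load and the shape of the known_datasets entries are not shown in this README, so the loader function load_my_dataset, its return value, and the (name, datadir) arguments are all assumptions made for the example.

```python
from typing import Callable, Dict, List

# Hypothetical loader: the real loaders in this repository may take
# different arguments and return different objects.
def load_my_dataset(datadir: str) -> Dict[str, List[str]]:
    """Return one (empty) list of example paths per split."""
    return {"train": [], "valid": [], "test": []}

# Assumed shape of the registry: dataset name -> loader function.
known_datasets: Dict[str, Callable[[str], Dict[str, List[str]]]] = {
    "my_dataset": load_my_dataset,
}

def load(name: str, datadir: str) -> Dict[str, List[str]]:
    """Look up a dataset by name and call its loader."""
    if name not in known_datasets:
        available = ", ".join(sorted(known_datasets))
        raise ValueError(f"Unknown dataset {name!r}; available: {available}")
    return known_datasets[name](datadir)

if __name__ == "__main__":
    splits = load("my_dataset", "/some/path")
    print(sorted(splits))
```

With a layout like this, supporting a new dataset means writing one loader function and adding one dictionary entry.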

Below is some information about the default datasets. Note that some of them need to be downloaded manually.


This dataset needs to be downloaded manually and is available at

utils/dataset/waterbirds contains a slightly modified version of the original script, which can be used to generate variants with different percentages of mismatched backgrounds.


Created from the GQA dataset. The file utils/datasets/treeperson/metadata.csv contains the list of images chosen from the GQA dataset, together with their labels and splits.

This dataset needs to be downloaded manually. Place the GQA images in /some/path/images/, then copy utils/datasets/treeperson/metadata.csv to /some/path/metadata.csv.
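The directory layout described above can be set up with a few lines of Python. This is a minimal sketch, not part of the repository: the helper name prepare_treeperson is hypothetical, and it only creates the images/ folder and copies metadata.csv; the GQA images still have to be placed in images/ by hand.

```python
import shutil
import tempfile
from pathlib import Path

def prepare_treeperson(root: str, metadata_src: str) -> Path:
    """Create <root>/images/ and copy the metadata file to <root>/metadata.csv.

    `root` plays the role of /some/path in the instructions above.
    """
    root_dir = Path(root)
    # The GQA images are expected to be placed in this folder afterwards.
    (root_dir / "images").mkdir(parents=True, exist_ok=True)
    target = root_dir / "metadata.csv"
    shutil.copyfile(metadata_src, target)
    return target

if __name__ == "__main__":
    # Demonstration in a throwaway directory with a placeholder metadata file.
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "source_metadata.csv"
        src.write_text("placeholder\n")
        copied = prepare_treeperson(str(Path(tmp) / "treeperson"), str(src))
        print(copied.name, (copied.parent / "images").is_dir())
```

In real use, metadata_src would be utils/datasets/treeperson/metadata.csv and root the directory later passed to the run command as the data directory.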


This dataset will be downloaded automatically if necessary. It is also available at


(1) utils/datasets/ is only compatible with WILDS v1.1 and v1.2, because WILDS v2.0 changed the datasets' split dictionaries. Supporting v2.0 would require a small modification to the function load_wilds_datasets in utils/datasets/load.

(2) For this dataset, it might take a while for the run scripts above to build a seed set.
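Because of the version constraint in note (1), it can be useful to check the installed WILDS version before starting a long run. The sketch below is an assumption-laden helper and not part of this repository: the function names are hypothetical, and it assumes the package is installed under the distribution name wilds.

```python
from importlib import metadata
from typing import Optional

# Versions that note (1) lists as compatible: v1.1 and v1.2.
SUPPORTED_MINOR_VERSIONS = {(1, 1), (1, 2)}

def is_supported_wilds(version: str) -> bool:
    """Return True if a version string such as '1.2.2' is in the 1.1 or 1.2 series."""
    parts = version.split(".")
    try:
        major_minor = (int(parts[0]), int(parts[1]))
    except (IndexError, ValueError):
        # Unparseable version strings are treated as unsupported.
        return False
    return major_minor in SUPPORTED_MINOR_VERSIONS

def installed_wilds_version() -> Optional[str]:
    """Return the installed version of the 'wilds' distribution, or None."""
    try:
        return metadata.version("wilds")
    except metadata.PackageNotFoundError:
        return None

if __name__ == "__main__":
    version = installed_wilds_version()
    if version is None:
        print("wilds is not installed")
    elif not is_supported_wilds(version):
        print(f"wilds {version} is installed, but only v1.1 and v1.2 are supported")
    else:
        print(f"wilds {version} is supported")
```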


Weights: The model weight file, if required, should be downloaded to the main directory. The Google BiT model weights are available at the official repository linked above. For example, to download the BiT-M-R50x1 model weights, run


Run: To train a model and print training progress, validation accuracy, etc., run python3 -m run (flags). For example:

python3 -m run --name test_run --model BiT-M-R50x1 --logdir /path/to/dir --dataset waterbirds --datadir /datasets/waterbird_complete95_forest2water2/ --target_attr bird --valid_splits out_sample

For the list of flags, either run

python3 -m run -h

or see models/

Quick start: It might be more convenient to run experiments from a script. Some sample scripted runs are provided in sample_run_scripts. To use them, change the dataset_path and logdir_base variables to the appropriate paths, then run one of the following:

python3 -m sample_run_scripts.waterbirds
python3 -m sample_run_scripts.treeperson
python3 -m sample_run_scripts.iwildcam
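A scripted run can also be assembled by hand. The sketch below is not how the provided sample scripts work (their contents are not shown here); it simply turns a dictionary of flags into the python3 -m run command line from the example above. The helper name build_command is hypothetical, and the flag names and values are copied from that example.

```python
import shlex
import sys
from typing import Dict, List

def build_command(flags: Dict[str, str]) -> List[str]:
    """Build the argument list for `python -m run` from a flag dictionary."""
    cmd = [sys.executable, "-m", "run"]
    for name, value in flags.items():
        # Each entry becomes "--name value", in insertion order.
        cmd += [f"--{name}", str(value)]
    return cmd

if __name__ == "__main__":
    # Flags taken from the example command in this README.
    flags = {
        "name": "test_run",
        "model": "BiT-M-R50x1",
        "logdir": "/path/to/dir",
        "dataset": "waterbirds",
        "datadir": "/datasets/waterbird_complete95_forest2water2/",
        "target_attr": "bird",
        "valid_splits": "out_sample",
    }
    # Print the command instead of launching it.
    print(shlex.join(build_command(flags)))
```

To actually launch the run from the repository root, the resulting list could be passed to subprocess.run(cmd, check=True).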


