```bash
git clone https://github.com/EyalMichaeli/SaSPA-Aug.git
cd SaSPA-Aug
conda env create -f environment.yml
conda activate saspa
```
For a quick setup, we will run on the planes dataset, which is downloaded automatically.

1. Download the `planes.pth` checkpoint from Google Drive into `all_utils/checkpoints/planes`.
2. Run `run_aug/run_aug.py`. This script generates the augmentations and filters them at the end. The augmentations are saved to `data/FGVC-Aircraft/fgvc-aircraft-2013b/data/aug_data/controlnet/sd_v1.5/canny/gpt-meta_class_prompt_w_sub_class_artistic_prompts_p_0.5_seed_1`, together with a log file and, once generation is done, a JSON file that maps the original file paths to their respective augmentations.
3. Once you have the JSON file, copy its path into `fgvc/trainings_scripts/consecutive_runs_aug.sh`, under the `aug_json` variable. Then run:

```bash
bash fgvc/trainings_scripts/consecutive_runs_aug.sh
```

That's it! You should see your training start, with logs at `<repo_path>/logs/<dataset_name>/`.
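To sanity-check the generated JSON before training, you can inspect a few entries. Below is a minimal sketch, assuming the file is a flat mapping from each original image path to its augmentation paths; the exact schema and file name are assumptions, so check the file produced by your run:

```python
import json

# Hypothetical path; use the JSON file reported at the end of run_aug/run_aug.py.
aug_json = "path/to/your/aug.json"

with open(aug_json) as f:
    mapping = json.load(f)

# Print the first few original -> augmentation pairs.
for orig_path, aug_paths in list(mapping.items())[:3]:
    print(orig_path, "->", aug_paths)
```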
The datasets are obtained as follows:

- Aircraft, Cars, and DTD: downloaded automatically via torchvision to the local folder `data/<dataset_name>`.
- CUB: download from Caltech-UCSD Birds-200-2011 to `data/CUB`.
- CompCars: download from the CompCars dataset page to `data/compcars`.
If the original dataset does not include a validation set, file-name splits are provided in `fgvc/datasets_files` and are loaded automatically.
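For illustration, such a split file can be read as a plain list of image file names. A minimal sketch, assuming one file name per line (the actual format used in `fgvc/datasets_files` may differ):

```python
from pathlib import Path

def read_split(split_file: str) -> list[str]:
    """Read a split file with one image file name per line, skipping blanks."""
    lines = Path(split_file).read_text().splitlines()
    return [line.strip() for line in lines if line.strip()]

# Hypothetical file name; the real split files depend on the dataset.
val_files = read_split("fgvc/datasets_files/planes_val.txt")
```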
In our experiments, we use Weights & Biases (wandb) to monitor training. The training script connects to wandb automatically; to disable this, set the `DONT_WANDB` variable in `train.py` to `True`.
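The flag amounts to a simple guard around wandb initialization. A minimal sketch, assuming `train.py` checks a module-level constant; the actual wiring in the repo may differ, and the project name below is a placeholder:

```python
import wandb

DONT_WANDB = True  # set to True to skip wandb logging entirely

def init_logging(run_name: str, config: dict):
    if DONT_WANDB:
        # "disabled" mode turns wandb calls into no-ops, with no network access.
        wandb.init(mode="disabled")
    else:
        wandb.init(project="saspa-aug", name=run_name, config=config)
```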
For our filtering, we need a baseline model trained on the original dataset. We provide pre-trained checkpoints for each dataset used in our paper on Google Drive; either download one or train a baseline model yourself, then move the checkpoint to the folder `all_utils/checkpoints/<dataset_name>`.
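It can be worth verifying that the checkpoint loads before running the filtering. A minimal sketch, assuming a standard PyTorch checkpoint file (the path below is just an example):

```python
import torch

# Example path; point this at the checkpoint you downloaded or trained.
ckpt = torch.load("all_utils/checkpoints/planes/planes.pth", map_location="cpu")

# Checkpoints are typically a state dict, or a dict wrapping one.
state = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
print(f"Loaded checkpoint with {len(state)} entries")
```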
To create the prompts using GPT-4, follow the instructions in the paper. The generated prompts should be placed in `prompts_engineering/gpt_prompts`, which currently contains our generated prompts.
The generation code is located at `run_aug/run_aug.py`.

Choose a dataset and ensure that `BASE_MODEL = "blip_diffusion"` and `CONTROLNET = "canny"`. If you don't want to use BLIP-Diffusion, you can use other base models such as `sd_v1.5` or `sd_xl-turbo`. (It is currently set to `sd_v1.5`, which works better for the Aircraft dataset; for all other datasets, set `BASE_MODEL = "blip_diffusion"`.)
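Concretely, the relevant settings in `run_aug/run_aug.py` would look roughly like this (values per the recommendations above; the comments are ours):

```python
# Settings in run_aug/run_aug.py.
BASE_MODEL = "sd_v1.5"   # best for Aircraft; use "blip_diffusion" for all other datasets
CONTROLNET = "canny"     # edge-conditioned ControlNet, used to preserve structure
# Other supported base models include "sd_xl-turbo".
```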
The code will generate augmentations in the folder `<dataset_root>/aug_data` and will then automatically generate a JSON file with the filtered augmentations.

Once you have the JSON file, copy its path into `fgvc/trainings_scripts/consecutive_runs_aug.sh`, under the `aug_json` variable. Make sure the correct dataset is specified in the `dataset` variable and fill in the rest of the arguments (GPU ID, `run_name`, etc.). The appropriate training arguments, such as the augmentation ratio and the traditional augmentations used, are currently chosen automatically in the script based on the dataset name.

Once the script arguments are ready, run:

```bash
bash fgvc/trainings_scripts/consecutive_runs_aug.sh
```

That's it! You should see your training start, with logs at `<repo_path>/logs/<dataset_name>/`.
To incorporate new datasets into the project, follow these steps:

- Prompt Creation: begin by generating and adding new prompts to `prompts_engineering/gpt_prompts`.
- Dataset Class Development: add a new dataset class in `all_utils/dataset_utils.py` to manage dataset-specific functionality (a hypothetical skeleton is sketched after this list).
- Dataset Module Implementation: add a new Python file in the `fgvc/datasets` folder.
- Dataset Config: add a new Python file with hyper-parameters in the `fgvc/configs` folder.
- Baseline Model Training: train a baseline model to ensure the new dataset is correctly integrated and functional; this model is also used in the filtering process.
- Follow Standard Procedures: proceed with the regular augmentation and training workflows as documented in Running the Code.
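For the dataset-class step, a rough skeleton is sketched below. It is hypothetical: the base class and method names that `all_utils/dataset_utils.py` actually expects may differ, and everything here beyond the file location is an assumption.

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class MyNewDataset(Dataset):
    """Hypothetical skeleton for a new dataset class in all_utils/dataset_utils.py."""

    def __init__(self, root: str, split: str = "train", transform=None):
        self.root = Path(root)
        self.transform = transform
        self.samples = self._load_split(split)  # list of (image_path, label) pairs

    def _load_split(self, split: str):
        # Dataset-specific: parse annotation files, or the splits in fgvc/datasets_files.
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```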
If you find our work useful, we welcome citations:
```bibtex
@inproceedings{
  michaeli2024advancing,
  title={Advancing Fine-Grained Classification by Structure and Subject Preserving Augmentation},
  author={Eyal Michaeli and Ohad Fried},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=MNg331t8Tj}
}
```
- You might need to re-install PyTorch to match your server's CUDA setup; a quick check is shown below.
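To confirm that the installed PyTorch build sees your GPU:

```python
import torch

# Prints the PyTorch version, its CUDA build, and whether a GPU is visible.
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
```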
We extend our gratitude to the following resources for their significant contributions to our project:
- CAL Repository: Visit CAL Repo for more details.
- Diffusers Package: Learn more about the Diffusers package at Hugging Face Diffusers Documentation.