ICD-LM: Configuring Vision-Language In-Context Demonstrations by Language Modeling
git clone
conda create -n icdlm python=3.10
conda activate icdlm
pip install -r requirements.txt
# install the openicl package
mkdir requirements_repo
cd requirements_repo
# For the anonymous submission; this will be fixed in the formal version.
git clone https://github.com/ForJadeForest/OpenICL.git
cd OpenICL
git checkout -b coco_caption origin/coco_caption
pip install -e ./
cd ../..
You should set the environment variables for the dataset paths and the OpenFlamingo checkpoint path:
CHECKPOINT_PATH="./openflamingo" # where the OpenFlamingo checkpoint will be saved
COCO_PATH="/path/to/mscoco"
VQAV2_PATH="/path/to/vqav2"
RESULT_DIR="/path/to/result" # directory for saving results (checkpoints, inference metrics, cache, ...)
The Flamingo checkpoint will be downloaded automatically.
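Before running the scripts, it can help to confirm that these variables are set. The sketch below is a hypothetical helper, not part of the repository. It assumes the dataset roots must already exist, while CHECKPOINT_PATH and RESULT_DIR only need to be set because they may be created later. VQAV2_PATH is only needed when you use vqav2_local.

```python
import os
from pathlib import Path

# Hypothetical helper: variable names come from the list above.
# Dataset roots must already exist; CHECKPOINT_PATH and RESULT_DIR may be
# created later by the scripts, so they only need to be set.
ALL_VARS = ("CHECKPOINT_PATH", "COCO_PATH", "VQAV2_PATH", "RESULT_DIR")
DATASET_VARS = ("COCO_PATH", "VQAV2_PATH")


def check_env(env=None, required=ALL_VARS, must_exist=DATASET_VARS):
    """Return a list of problems found in the path-related environment variables."""
    env = os.environ if env is None else env
    problems = []
    for name in required:
        if not env.get(name):
            problems.append(f"{name} is not set")
    for name in must_exist:
        value = env.get(name)
        if value and not Path(value).is_dir():
            problems.append(f"{name}={value} is not an existing directory")
    return problems


if __name__ == "__main__":
    # VQAV2_PATH is only needed for vqav2_local; drop it from `required` otherwise.
    for problem in check_env():
        print("warning:", problem)
```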
We use the MSCOCO train split to generate the ICD sequences, so you should prepare MSCOCO 2017 or MSCOCO 2014 with the following layout:
|-- mscoco
|   |
|   |- mscoco2017
|   |   |
|   |   |- train2017
|   |   |- val2017
|   |   |- annotations
|   |       |
|   |       |- captions_train2017.json
|   |       |- captions_val2017.json
|   |- mscoco2014
|       |
|       |- train2014
|       |- val2014
|       |- annotations
|           |
|           |- captions_train2014.json
|           |- captions_val2014.json
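To catch layout mistakes early, the tree above can be checked programmatically. This is a hypothetical helper, not part of the repository. It assumes that COCO_PATH points at the mscoco directory that contains mscoco2017 and/or mscoco2014.

```python
import os
from pathlib import Path


def missing_coco_files(coco_root, version="2017"):
    """List entries of the expected MSCOCO layout that do not exist.

    `version` is "2017" or "2014"; the expected paths mirror the tree above.
    """
    base = Path(coco_root) / f"mscoco{version}"
    expected = [
        base / f"train{version}",
        base / f"val{version}",
        base / "annotations" / f"captions_train{version}.json",
        base / "annotations" / f"captions_val{version}.json",
    ]
    return [str(p) for p in expected if not p.exists()]


if __name__ == "__main__":
    # Assumption: COCO_PATH is the `mscoco` directory from the tree above.
    root = os.environ.get("COCO_PATH", "/path/to/mscoco")
    for path in missing_coco_files(root, "2017"):
        print("missing:", path)
```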
We use the VQAv2 train split to generate high-quality ICD sequences.
So you should prepare the VQAv2 dataset locally, or, if you can download datasets from Hugging Face, use configs/dataset/vqav2_online.yaml.
# Download the VQAv2 dataset (-P sets the target directory):
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip -P /path/to/vqav2/
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip -P /path/to/vqav2/
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip -P /path/to/vqav2/
wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip -P /path/to/vqav2/
cd /path/to/vqav2/
unzip v2_Annotations_Train_mscoco.zip
unzip v2_Annotations_Val_mscoco.zip
unzip v2_Questions_Train_mscoco.zip
unzip v2_Questions_Val_mscoco.zip
# Preprocess the dataset; run this from the repository root, since src/ is a relative path.
python src/dataset_module/preprocess/vqav2_hf.py --root_path /path/to/vqav2/
Then set the VQAV2_PATH environment variable in .env (if you use vqav2_local as the dataset).
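The download steps above produce the standard VQAv2 question and annotation JSON files. As a rough illustration of what a preprocessing step has to do, the sketch below joins questions and annotations by question_id. It is not the logic of src/dataset_module/preprocess/vqav2_hf.py. The file names in the hypothetical load_split helper are the standard names inside the official archives; treat them as an assumption and adjust them if your extracted files differ.

```python
import json
from pathlib import Path


def join_vqav2(questions, annotations):
    """Merge VQAv2 question and annotation records by question_id."""
    by_qid = {a["question_id"]: a for a in annotations["annotations"]}
    samples = []
    for q in questions["questions"]:
        ann = by_qid.get(q["question_id"])
        if ann is None:
            continue  # question without an annotation (e.g. a test split)
        samples.append({
            "question_id": q["question_id"],
            "image_id": q["image_id"],
            "question": q["question"],
            "answer": ann["multiple_choice_answer"],
            "answers": [a["answer"] for a in ann["answers"]],
        })
    return samples


def load_split(root, split="train"):
    """Hypothetical loader; assumes the standard file names from the archives."""
    root = Path(root)
    with open(root / f"v2_OpenEnded_mscoco_{split}2014_questions.json") as f:
        questions = json.load(f)
    with open(root / f"v2_mscoco_{split}2014_annotations.json") as f:
        annotations = json.load(f)
    return join_vqav2(questions, annotations)
```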
# for coco2017 image captioning
bash scripts/generate_data.sh caption coco2017 "[0,1,2,3]"
# for vqav2
bash scripts/generate_data.sh vqa vqav2_local "[0,1,2,3]"
# or use the Hugging Face VQAv2 dataset, which is downloaded automatically
bash scripts/generate_data.sh vqa vqav2_online "[0,1,2,3]"
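To give an intuition for what "generating good ICD sequences" can mean, here is a minimal greedy sketch. At each step it appends the candidate ICD that maximizes a score of the configuration chosen so far. The `score` function is a hypothetical stand-in for an LVLM-based quality measure, and the procedure in scripts/generate_data.sh may differ in its scoring, candidate sampling, and search strategy.

```python
from typing import Callable, List, Sequence


def greedy_icd_sequence(
    anchor: int,
    candidates: Sequence[int],
    score: Callable[[int, List[int]], float],
    shot_num: int,
) -> List[int]:
    """Greedily build an ICD sequence for one anchor sample.

    `score(anchor, icds)` is hypothetical: higher means the ICD configuration
    `icds` helps the model more on `anchor`.
    """
    chosen: List[int] = []
    pool = [c for c in candidates if c != anchor]  # never use the anchor itself
    for _ in range(min(shot_num, len(pool))):
        best = max(pool, key=lambda c: score(anchor, chosen + [c]))
        chosen.append(best)
        pool.remove(best)
    return chosen


# Toy score: prefer candidates whose index is close to the anchor's.
toy_score = lambda a, seq: -sum(abs(a - c) for c in seq)
print(greedy_icd_sequence(5, range(10), toy_score, 3))
```

Greedy selection keeps the cost linear in the number of shots. A beam search over partial sequences would explore more configurations at a higher cost.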
# for coco2017 image captioning
bash scripts/train_icd_lm.sh caption coco2017 1 query_img_icd_img_text
# for vqav2
bash scripts/train_icd_lm.sh vqa vqav2_local 1 query_img_text_icd_img_text
# or use hf vqav2 dataset
bash scripts/train_icd_lm.sh vqa vqav2_online 1 query_img_text_icd_img_text
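ICD-LM treats ICD configuration as a language-modeling problem. The sketch below illustrates the idea under stated assumptions. The vocabulary is assumed to be the index set of the ICD pool plus hypothetical BOS/EOS ids, and each generated sequence becomes one next-token training example scored with cross-entropy. Conditioning on the query sample, which the train option controls, is omitted, and the repository's actual model and loss may differ.

```python
import math
from typing import List, Sequence, Tuple


def make_lm_example(
    icd_seq: Sequence[int], bos_id: int, eos_id: int
) -> Tuple[List[int], List[int]]:
    """Turn one ICD index sequence into (input_ids, labels) shifted by one step."""
    tokens = [bos_id] + list(icd_seq) + [eos_id]
    return tokens[:-1], tokens[1:]


def next_token_nll(step_logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Average negative log-likelihood of `labels` under per-step logits."""
    total = 0.0
    for logits, target in zip(step_logits, labels):
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_z - logits[target]
    return total / len(labels)
```

For example, with a pool of 10 ICDs one might use BOS = 10 and EOS = 11 (hypothetical ids), so that `make_lm_example([3, 7], 10, 11)` yields inputs `[10, 3, 7]` and labels `[3, 7, 11]`.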
python train.py
Args:
- train: Options are query_img_icd_idx, query_img_icd_img_text, query_img_icd_img, query_img_icd_text, and query_img_text_icd_img_text. 'img' after 'query' indicates that image information is added to the query sample, and 'text' after 'query' indicates that text information is added to the query sample. The same applies to 'icd'.
- dataset: Defines the dataset for in-context learning. For caption tasks, choose either coco2017 or coco2014; for VQA tasks, choose either vqav2_local or vqav2_online. This parameter also includes the dataset path and other relevant information.
- task: Options are vqa or caption; configures the prompt-related parameters.
- data_files: Specifies the names of the JSON data files generated in the first step.
- trainer_args: The Lightning Trainer arguments.
- lr: Learning rate.
- ex_name: Name of the current experiment, which is also the name of the folder where the experimental results are saved.
- seed: Sets the seed for random number generation.
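The naming scheme of the train options can be decoded mechanically: the tokens between 'query' and 'icd' describe the query sample, and the tokens after 'icd' describe the ICDs. The hypothetical helper below only illustrates that rule. The source does not describe 'idx', so the helper simply passes it through as a token.

```python
def parse_train_option(name: str) -> dict:
    """Split e.g. 'query_img_text_icd_img_text' into query/icd feature sets."""
    parts = name.split("_")
    if not parts or parts[0] != "query" or "icd" not in parts:
        raise ValueError(f"unexpected option name: {name}")
    icd_pos = parts.index("icd")
    return {
        "query": set(parts[1:icd_pos]),   # e.g. {"img"} or {"img", "text"}
        "icd": set(parts[icd_pos + 1:]),  # e.g. {"img", "text"} or {"idx"}
    }


print(parse_train_option("query_img_icd_img_text"))
```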
# for coco2017 image captioning
bash scripts/inference.sh caption coco2017 0 query_img_icd_img_text
# for vqav2
bash scripts/inference.sh vqa vqav2_local 0 query_img_text_icd_img_text
# or use hf vqav2 dataset
bash scripts/inference.sh vqa vqav2_online 0 query_img_text_icd_img_text
# You can use a VQAv2 validation subset, which contains only 10,000 samples, to evaluate performance faster.
bash scripts/inference.sh vqa vqav2_local_sub 0 query_img_text_icd_img_text
python inference_flamingo.py
To test the ICD-LM model, set:
- train: Options are query_img_icd_idx, query_img_icd_img_text, query_img_icd_img, query_img_icd_text, and query_img_text_icd_img_text. 'img' after 'query' indicates that image information is added to the query sample, and 'text' after 'query' indicates that text information is added to the query sample. The same applies to 'icd'.
- icd_lm_path: Path to the model checkpoint.
- test_icd_lm: Set to true.
- random_order_icd_lm_iocd: If set to True, the ICD configuration generated by ICD-LM is randomly shuffled.
- default_cpk_key: The checkpoint keyword. You can set it to last or min_loss.
- ex_name: Name of the current experiment, which is also the name of the folder where the inference results are saved.
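The random_order_icd_lm_iocd flag keeps the set of ICDs chosen by ICD-LM but discards their order, which makes it possible to check whether the generated order matters. A minimal sketch of that behavior follows. The helper name and the seeding are assumptions; the repository's implementation may seed differently.

```python
import random
from typing import List, Optional, Sequence


def maybe_shuffle(icds: Sequence[int], random_order: bool, seed: Optional[int] = None) -> List[int]:
    """Return the ICD configuration, randomly shuffled if `random_order` is True."""
    result = list(icds)  # copy so the caller's sequence is never modified
    if random_order:
        random.Random(seed).shuffle(result)  # local RNG: reproducible given `seed`
    return result
```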
Other args:
- dataset: Defines the dataset for in-context learning. For caption tasks, choose either coco2017 or coco2014; for VQA tasks, choose either vqav2_local or vqav2_online.
- task: Options are vqa or caption; configures the prompt-related parameters.
- flamingo: Flamingo model version; options are flamingo_3B and flamingo_9B.
- index_data_num: Number of items in the ICD training set; -1 means all.
- test_data_num: Number of items in the test set; -1 means all.
- inference_bs: Batch size for inference. On an RTX 3090 with 24 GB of memory, a batch size of 4 is feasible for 16 shots.
- shot_num_list: The list of shot numbers to test.
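The task option configures the prompt. The sketch below shows how ICDs and a query might be interleaved into a Flamingo prompt. It uses the `<image>` and `<|endofchunk|>` markers and "Output:" / "Question: ... Short answer:" templates modeled on OpenFlamingo's evaluation code. The templates this repository actually uses may differ. The hypothetical take_subset helper illustrates the -1 convention of index_data_num and test_data_num, assuming a simple first-N cut; the repository may select items differently.

```python
from typing import Dict, List, Sequence


def format_example(task: str, sample: Dict[str, str], with_answer: bool = True) -> str:
    """Format one sample; the query is formatted without its answer."""
    if task == "caption":
        text = "<image>Output:"
        if with_answer:
            text += sample["caption"] + "<|endofchunk|>"
    elif task == "vqa":
        text = f"<image>Question:{sample['question']} Short answer:"
        if with_answer:
            text += sample["answer"] + "<|endofchunk|>"
    else:
        raise ValueError(f"unknown task: {task}")
    return text


def build_prompt(task: str, icds: Sequence[Dict[str, str]], query: Dict[str, str]) -> str:
    """Concatenate the ICDs (shot_num of them) followed by the unanswered query."""
    return "".join(format_example(task, s) for s in icds) + format_example(task, query, with_answer=False)


def take_subset(data: Sequence, num: int) -> List:
    """Hypothetical: -1 keeps everything, otherwise keep the first `num` items."""
    data = list(data)
    return data if num == -1 else data[:num]
```

The prompt contains one `<image>` marker per ICD plus one for the query, so a 16-shot prompt carries 17 images. This is why the batch size has to shrink as the shot number grows.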
Test retrieval-based methods:
- test_random: Use RS as the retrieval method.
- test_t2t: Use STTR as the retrieval method.
- test_i2t: Use SITR as the retrieval method.
- test_i2i: Use SIIR as the retrieval method.
- mmtopk_clip_name: Name of the CLIP model used to compute the similarity.
- mmtopk_reversed_order: If set to True, the rightmost ICD is the most similar; if set to False, the leftmost ICD is the most similar.
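The similarity-based retrievers rank the ICD pool by similarity to the query and keep the top k. The sketch below shows that ranking and the effect of mmtopk_reversed_order. It assumes the embeddings are precomputed, for example with the CLIP model named by mmtopk_clip_name; judging by the option names, test_i2i would compare image to image, test_i2t image to text, and test_t2t text to text. The helper names are hypothetical.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def topk_icds(query_emb: Sequence[float], pool_embs: Sequence[Sequence[float]],
              k: int, reversed_order: bool) -> List[int]:
    """Indices of the k most similar pool items.

    reversed_order=False: most similar ICD first (leftmost).
    reversed_order=True: most similar ICD last (rightmost, next to the query).
    """
    ranked = sorted(range(len(pool_embs)),
                    key=lambda i: cosine(query_emb, pool_embs[i]),
                    reverse=True)[:k]
    return ranked[::-1] if reversed_order else ranked
```

Placing the most similar ICD rightmost puts it directly before the query in the prompt.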