From 23e29d0e9bc872a9ac521fd486f374aeb4d221f0 Mon Sep 17 00:00:00 2001
From: Keshav Shivkumar
Date: Mon, 1 May 2023 18:16:49 -0400
Subject: [PATCH 1/3] ViLT GQA

---
 run.py                                   |   3 -
 vilt/config.py                           |  34 +++++-
 vilt/datamodules/__init__.py             |   2 +
 vilt/datamodules/datamodule_base.py      |  14 ++-
 vilt/datamodules/gqa_datamodule.py       |  63 ++++++++++
 vilt/datamodules/multitask_datamodule.py |   2 +-
 vilt/datasets/__init__.py                |   1 +
 vilt/datasets/base_dataset.py            |  46 ++++++--
 vilt/datasets/gqa_dataset.py             |  56 +++++++++
 vilt/gadgets/my_metrics.py               |  27 +++++
 vilt/modules/objectives.py               |  78 ++++++++++++
 vilt/modules/vilt_module.py              |  38 +++++-
 vilt/modules/vilt_utils.py               |  14 ++-
 vilt/utils/write_gqa.py                  | 128 ++++++++++++++++++++
 14 files changed, 481 insertions(+), 25 deletions(-)
 create mode 100644 vilt/datamodules/gqa_datamodule.py
 create mode 100644 vilt/datasets/gqa_dataset.py
 create mode 100644 vilt/utils/write_gqa.py

diff --git a/run.py b/run.py
index 0719e6d..0c97b12 100644
--- a/run.py
+++ b/run.py
@@ -11,7 +11,6 @@ def main(_config):
     _config = copy.deepcopy(_config)
     pl.seed_everything(_config["seed"])
-
     dm = MTDataModule(_config, dist=True)
 
     model = ViLTransformerSS(_config)
@@ -29,7 +28,6 @@ def main(_config):
         _config["log_dir"],
         name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
     )
-
     lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
     callbacks = [checkpoint_callback, lr_callback]
 
@@ -44,7 +42,6 @@ def main(_config):
     )
 
     max_steps = _config["max_steps"] if _config["max_steps"] is not None else None
-
     trainer = pl.Trainer(
         gpus=_config["num_gpus"],
         num_nodes=_config["num_nodes"],
diff --git a/vilt/config.py b/vilt/config.py
index f86f5d7..eb9f5e3 100644
--- a/vilt/config.py
+++ b/vilt/config.py
@@ -2,7 +2,6 @@
 
 ex = Experiment("ViLT")
 
-
 def _loss_names(d):
     ret = {
         "itm": 0,
@@ -11,6 +10,7 @@ def _loss_names(d):
         "vqa": 0,
         "nlvr2": 0,
         "irtr": 0,
+        "gqa": 0,
     }
     ret.update(d)
     return ret
@@ -35,6 +35,7 @@ def config():
 
     # Text Setting
     vqav2_label_size = 3129
+    gqa_label_size = 1878
     max_text_len = 40
     tokenizer = "bert-base-uncased"
     vocab_size = 30522
@@ -77,7 +78,7 @@ def config():
     num_gpus = 1
     num_nodes = 1
     load_path = ""
-    num_workers = 8
+    num_workers = 4
     precision = 16
 
@@ -179,6 +180,35 @@ def task_finetune_vqa_randaug():
     val_check_interval = 0.1
     lr_mult = 10
 
+@ex.named_config
+def task_finetune_gqa():
+    exp_name = "finetune_gqa"
+    datasets = ["gqa"]
+    loss_names = _loss_names({"gqa": 1})
+    batch_size = 256
+    max_epoch = 10
+    max_steps = None
+    warmup_steps = 0.1
+    draw_false_image = 0
+    learning_rate = 1e-4
+    val_check_interval = 0.1
+    lr_mult = 10
+
+@ex.named_config
+def task_finetune_gqa_randaug():
+    exp_name = "finetune_gqa_randaug"
+    datasets = ["gqa"]
+    train_transform_keys = ["pixelbert_randaug"]
+    loss_names = _loss_names({"gqa": 1})
+    batch_size = 256
+    max_epoch = 10
+    max_steps = None
+    warmup_steps = 0.1
+    draw_false_image = 0
+    learning_rate = 1e-4
+    val_check_interval = 0.1
+    lr_mult = 10
+
 @ex.named_config
 def task_finetune_irtr_coco():
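
The two new named configs differ only in `train_transform_keys`; both enable just the GQA head through `loss_names`. A standalone sketch of what the `_loss_names` helper above returns, with the key set abridged to the tasks visible in this patch:

```python
# Standalone copy of the helper above (keys abridged): every task weight
# defaults to 0 and the named config overrides only the tasks it wants.
def _loss_names(d):
    ret = {"itm": 0, "vqa": 0, "nlvr2": 0, "irtr": 0, "gqa": 0}
    ret.update(d)
    return ret

# task_finetune_gqa and task_finetune_gqa_randaug both pass {"gqa": 1}:
print(_loss_names({"gqa": 1}))
# {'itm': 0, 'vqa': 0, 'nlvr2': 0, 'irtr': 0, 'gqa': 1}
```
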
"nlvr2": NLVR2DataModule, + "gqa": GQADataModule, } diff --git a/vilt/datamodules/datamodule_base.py b/vilt/datamodules/datamodule_base.py index b8a3ec1..61948d9 100644 --- a/vilt/datamodules/datamodule_base.py +++ b/vilt/datamodules/datamodule_base.py @@ -1,4 +1,5 @@ import torch +import os from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -132,18 +133,15 @@ def set_test_dataset(self): image_only=self.image_only, ) - def setup(self, stage): - if not self.setup_flag: + def setup(self, stage=None): + if stage == "fit" or stage == "test" or stage is None: self.set_train_dataset() - self.set_val_dataset() - self.set_test_dataset() - self.train_dataset.tokenizer = self.tokenizer + self.set_val_dataset() self.val_dataset.tokenizer = self.tokenizer + self.set_test_dataset() self.test_dataset.tokenizer = self.tokenizer - self.setup_flag = True - def train_dataloader(self): loader = DataLoader( self.train_dataset, @@ -175,4 +173,4 @@ def test_dataloader(self): pin_memory=True, collate_fn=self.test_dataset.collate, ) - return loader + return loader \ No newline at end of file diff --git a/vilt/datamodules/gqa_datamodule.py b/vilt/datamodules/gqa_datamodule.py new file mode 100644 index 0000000..6aa8db7 --- /dev/null +++ b/vilt/datamodules/gqa_datamodule.py @@ -0,0 +1,63 @@ +from vilt.datasets import GQADataset +from .datamodule_base import BaseDataModule +from collections import defaultdict +import numpy as np + +class GQADataModule(BaseDataModule): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + def dataset_cls(self): + return GQADataset + + @property + def dataset_name(self): + return "gqa" + + def setup(self, stage): + super().setup(stage) + + train_answers = self.train_dataset.table["answers"].to_pandas().tolist() + val_answers = self.val_dataset.table["answers"].to_pandas().tolist() + train_labels = self.train_dataset.table["answer_label"].to_pandas().tolist() + val_labels = self.val_dataset.table["answer_label"].to_pandas().tolist() + + all_answers = [c for c in train_answers + val_answers if c is not None] + + train_answer_tuples = [(label, answer) for labels, answers in zip(train_labels, train_answers) for label, answer in zip(labels.tolist(), answers.tolist())] + val_answer_tuples = [(label, answer) for labels, answers in zip(val_labels, val_answers) for label, answer in zip(labels.tolist(), answers.tolist())] + + train_answer2id = {answer: label for label, answer in train_answer_tuples} + val_answer2id = {answer: label for label, answer in val_answer_tuples} + # print([i for i in train_answer2id if train_answer2id[i]==2]) + # Merge train and val dictionaries, keeping the label ids from the train dictionary + self.answer2id = {**val_answer2id, **train_answer2id} + + self.num_class = len(self.answer2id) + self.id2answer = defaultdict(lambda: "unknown") + for k, v in self.answer2id.items(): + self.id2answer[v] = k + + # Print some samples from the training dataset + + # print("Training dataset samples:") + # for idx, sample in enumerate(self.train_dataset): + # if idx >= 10: + # break + # print('In GQADataModule') + # question = sample["text"] + # label = sample["gqa_label"] + # answer = self.id2answer[label] + # print(f"Question: {question}\nLabel: {label}\nAnswer: {answer}") + + # print("\nValidation dataset samples:") + # # Print some samples from the validation dataset + # for idx, sample in enumerate(self.val_dataset): + # if idx >= 10: + # break + # print('In GQADataModule') + # question = 
sample["text"] + # label = sample["gqa_label"] + # answer = self.id2answer[label] + # print(f"Question: {question}\nLabel: {label}\nAnswer: {answer}") \ No newline at end of file diff --git a/vilt/datamodules/multitask_datamodule.py b/vilt/datamodules/multitask_datamodule.py index e2ce1ad..6d75efb 100644 --- a/vilt/datamodules/multitask_datamodule.py +++ b/vilt/datamodules/multitask_datamodule.py @@ -32,7 +32,7 @@ def prepare_data(self): def setup(self, stage): for dm in self.dms: dm.setup(stage) - + self.train_dataset = ConcatDataset([dm.train_dataset for dm in self.dms]) self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.dms]) self.test_dataset = ConcatDataset([dm.test_dataset for dm in self.dms]) diff --git a/vilt/datasets/__init__.py b/vilt/datasets/__init__.py index 95ec6a3..ba244ad 100644 --- a/vilt/datasets/__init__.py +++ b/vilt/datasets/__init__.py @@ -5,3 +5,4 @@ from .sbu_caption_dataset import SBUCaptionDataset from .vqav2_dataset import VQAv2Dataset from .nlvr2_dataset import NLVR2Dataset +from .gqa_dataset import GQADataset diff --git a/vilt/datasets/base_dataset.py b/vilt/datasets/base_dataset.py index 470a942..cf63b0b 100644 --- a/vilt/datasets/base_dataset.py +++ b/vilt/datasets/base_dataset.py @@ -29,7 +29,6 @@ def __init__( """ assert len(transform_keys) >= 1 super().__init__() - self.transforms = keys_to_transforms(transform_keys, size=image_size) self.text_column_name = text_column_name self.names = names @@ -38,8 +37,8 @@ def __init__( self.draw_false_text = draw_false_text self.image_only = image_only self.data_dir = data_dir - if len(names) != 0: + print(f"Attempting to load the following dataset files: {[f'{data_dir}/{name}.arrow' for name in names]}") tables = [ pa.ipc.RecordBatchFileReader( pa.memory_map(f"{data_dir}/{name}.arrow", "r") @@ -47,12 +46,29 @@ def __init__( for name in names if os.path.isfile(f"{data_dir}/{name}.arrow") ] - - self.table_names = list() + + self.table = list() for i, name in enumerate(names): - self.table_names += [name] * len(tables[i]) + if i < len(tables): + self.table += [name] * len(tables[i]) + else: + print(f"Warning: Skipping {name} as the index is out of range in tables.") + + if self.table is None: + print("Error: The table is not properly loaded. Please check the dataset files and their paths.") + if len(tables) > 0: + self.table = pa.concat_tables(tables, promote=True) + else: + print("Warning: No tables to concatenate. 
Check if dataset is properly loaded.") + self.table = None + + # if self.table is not None: + # print("Column names in the table schema:", [field.name for field in self.table.schema]) + # print("Sample answer_scores:") + # for i in range(min(10, len(self.table))): + # print(f"Row {i}: {self.table['answer_scores'][i].as_py()}") + - self.table = pa.concat_tables(tables, promote=True) if text_column_name != "": self.text_column_name = text_column_name self.all_texts = self.table[text_column_name].to_pandas().tolist() @@ -65,6 +81,22 @@ def __init__( self.all_texts = list() else: self.all_texts = list() + + + if self.table is not None and text_column_name != "": + self.text_column_name = text_column_name + try: + self.all_texts = self.table[text_column_name].to_pandas().tolist() + self.all_texts = ( + [list(set(texts)) for texts in self.all_texts] + if remove_duplicate + else self.all_texts + ) + except KeyError: + print(f"Error: The text column '{text_column_name}' was not found in the table.") + self.all_texts = list() + else: + self.all_texts = list() self.index_mapper = dict() @@ -77,6 +109,8 @@ def __init__( else: for i in range(len(self.table)): self.index_mapper[i] = (i, None) + print("Length of tables:", [len(t) for t in tables]) + @property def corpus(self): diff --git a/vilt/datasets/gqa_dataset.py b/vilt/datasets/gqa_dataset.py new file mode 100644 index 0000000..c989902 --- /dev/null +++ b/vilt/datasets/gqa_dataset.py @@ -0,0 +1,56 @@ +from .base_dataset import BaseDataset + +class GQADataset(BaseDataset): + def __init__(self, *args, split="", **kwargs): + assert split in ["train", "val", "test", "testdev"] + self.split = split + self.print_counter = 0 # Add this line to initialize the counter + + if split == "train": + names = ["gqa_train", "gqa_trainable_val"] + elif split == "val": + names = ["gqa_rest_val"] + elif split == "test": + names = ["gqa_testdev"] + + super().__init__( + *args, + **kwargs, + names=names, + text_column_name="questions", + remove_duplicate=False, + ) + + def __getitem__(self, index): + image_tensor = self.get_image(index)["image"] + text = self.get_text(index)["text"] + + index, question_index = self.index_mapper[index] + qid = self.table["question_id"][index][question_index].as_py() + + if self.split != "test": + answers = self.table["answers"][index][question_index].as_py() + labels = self.table["answer_label"][index][question_index].as_py() + scores = self.table["answer_scores"][index][question_index].as_py() + else: + answers = list() + labels = list() + scores = list() + + # Print the first 5 questions and answers + # if self.print_counter < 5: + # print('In GQADataset') + # print(f"Question: {text}") + # print(f"Label: {labels}") + # print(f"Answers: {answers}") + # self.print_counter += 1 # Increment the counter + + return { + "image": image_tensor, + "text": text, + "gqa_answer": answers, + "gqa_label": labels, + "gqa_scores": scores, + "qid": qid, + } + diff --git a/vilt/gadgets/my_metrics.py b/vilt/gadgets/my_metrics.py index 64a3f7f..e798a84 100644 --- a/vilt/gadgets/my_metrics.py +++ b/vilt/gadgets/my_metrics.py @@ -67,3 +67,30 @@ def update(self, logits, target): def compute(self): return self.score / self.total + +class GQAScore(Metric): + def __init__(self, dist_sync_on_step=False): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, logits, target): + logits, target 
diff --git a/vilt/datasets/gqa_dataset.py b/vilt/datasets/gqa_dataset.py
new file mode 100644
index 0000000..c989902
--- /dev/null
+++ b/vilt/datasets/gqa_dataset.py
@@ -0,0 +1,56 @@
+from .base_dataset import BaseDataset
+
+class GQADataset(BaseDataset):
+    def __init__(self, *args, split="", **kwargs):
+        assert split in ["train", "val", "test", "testdev"]
+        self.split = split
+        self.print_counter = 0  # Add this line to initialize the counter
+
+        if split == "train":
+            names = ["gqa_train", "gqa_trainable_val"]
+        elif split == "val":
+            names = ["gqa_rest_val"]
+        elif split == "test":
+            names = ["gqa_testdev"]
+
+        super().__init__(
+            *args,
+            **kwargs,
+            names=names,
+            text_column_name="questions",
+            remove_duplicate=False,
+        )
+
+    def __getitem__(self, index):
+        image_tensor = self.get_image(index)["image"]
+        text = self.get_text(index)["text"]
+
+        index, question_index = self.index_mapper[index]
+        qid = self.table["question_id"][index][question_index].as_py()
+
+        if self.split != "test":
+            answers = self.table["answers"][index][question_index].as_py()
+            labels = self.table["answer_label"][index][question_index].as_py()
+            scores = self.table["answer_scores"][index][question_index].as_py()
+        else:
+            answers = list()
+            labels = list()
+            scores = list()
+
+        # Print the first 5 questions and answers
+        # if self.print_counter < 5:
+        #     print('In GQADataset')
+        #     print(f"Question: {text}")
+        #     print(f"Label: {labels}")
+        #     print(f"Answers: {answers}")
+        #     self.print_counter += 1  # Increment the counter
+
+        return {
+            "image": image_tensor,
+            "text": text,
+            "gqa_answer": answers,
+            "gqa_label": labels,
+            "gqa_scores": scores,
+            "qid": qid,
+        }
+
diff --git a/vilt/gadgets/my_metrics.py b/vilt/gadgets/my_metrics.py
index 64a3f7f..e798a84 100644
--- a/vilt/gadgets/my_metrics.py
+++ b/vilt/gadgets/my_metrics.py
@@ -67,3 +67,30 @@ def update(self, logits, target):
 
     def compute(self):
         return self.score / self.total
+
+class GQAScore(Metric):
+    def __init__(self, dist_sync_on_step=False):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+    def update(self, logits, target):
+        logits, target = (
+            logits.detach().float().to(self.score.device),
+            target.detach().float().to(self.score.device),
+        )
+        logits = torch.max(logits, 1)[1]
+        one_hots = torch.zeros(*target.size()).to(target)
+        one_hots.scatter_(1, logits.view(-1, 1), 1)
+        scores = one_hots * target
+        # print(f'Target: {target}')
+        # print(f'One Hots: {one_hots}')
+        # print(f'Scores: {scores}')
+
+        self.score += scores.sum()
+        self.total += len(logits)
+        # print(self.score)
+
+
+    def compute(self):
+        return self.score / self.total
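
Because each GQA question has a single gold answer, `GQAScore` reduces to top-1 accuracy: scatter the argmax prediction into a one-hot matrix, mask it against the target scores, and average. A toy check of the update rule with hand-picked tensors:

```python
import torch

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.2, 0.3]])
target = torch.tensor([[0.0, 1.0, 0.0],   # gold answer is class 1
                       [0.0, 0.0, 1.0]])  # gold answer is class 2

pred = torch.max(logits, 1)[1]                   # tensor([1, 0])
one_hots = torch.zeros(*target.size()).to(target)
one_hots.scatter_(1, pred.view(-1, 1), 1)        # one-hot of the argmax
scores = one_hots * target                       # nonzero only on hits
print(scores.sum() / len(pred))  # tensor(0.5000): first right, second wrong
```
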
diff --git a/vilt/modules/objectives.py b/vilt/modules/objectives.py
index 434f040..df64b30 100644
--- a/vilt/modules/objectives.py
+++ b/vilt/modules/objectives.py
@@ -335,6 +335,41 @@ def compute_vqa(pl_module, batch):
 
     return ret
 
+def compute_gqa(pl_module, batch):
+    infer = pl_module.infer(batch, mask_text=False, mask_image=False)
+    gqa_logits = pl_module.gqa_classifier(infer["cls_feats"])
+    gqa_targets = torch.zeros(
+        len(gqa_logits), pl_module.hparams.config["gqa_label_size"]
+    ).to(pl_module.device)
+
+    gqa_labels = batch["gqa_label"]
+    gqa_scores = batch["gqa_scores"]
+    for i, _label in enumerate(gqa_labels):
+        gqa_targets[i, _label] = torch.tensor(gqa_scores[i], device=pl_module.device)
+
+    gqa_loss = (
+        F.binary_cross_entropy_with_logits(gqa_logits, gqa_targets)
+        * gqa_targets.shape[1]
+    )
+
+    ret = {
+        "gqa_loss": gqa_loss,
+        "gqa_logits": gqa_logits,
+        "gqa_targets": gqa_targets,
+        "gqa_label": gqa_labels,
+        "gqa_scores": gqa_scores,
+    }
+
+    phase = "train" if pl_module.training else "val"
+    loss = getattr(pl_module, f"{phase}_gqa_loss")(ret["gqa_loss"])
+    score = getattr(pl_module, f"{phase}_gqa_score")(
+        ret["gqa_logits"], ret["gqa_targets"]
+    )
+    pl_module.log(f"gqa/{phase}/loss", loss)
+    pl_module.log(f"gqa/{phase}/score", score)
+
+    return ret
+
 
 def compute_nlvr2(pl_module, batch):
     infer1 = pl_module.infer(
@@ -588,6 +623,21 @@ def vqa_test_step(pl_module, batch, output):
     qids = batch["qid"]
     return {"qids": qids, "preds": vqa_preds}
 
+def gqa_test_step(pl_module, batch, output):
+    id2answer = (
+        pl_module.trainer.datamodule.dm_dicts["gqa_trainval"].id2answer
+        if "gqa_trainval" in pl_module.trainer.datamodule.dm_dicts
+        else pl_module.trainer.datamodule.dm_dicts["gqa"].id2answer
+    )
+    gqa_logits = output["gqa_logits"]
+    gqa_preds = gqa_logits.argmax(dim=-1)
+    gqa_preds = [id2answer[pred.item()] for pred in gqa_preds]
+    questions = batch["text"]
+    qids = batch["qid"]
+    for q, a in zip(questions, gqa_preds):
+        print(f'Question: {q}, Answer: {a}')
+    return {"qids": qids, "preds": gqa_preds}
+
 
 def arc_test_step(pl_module, batch, output):
     return output
@@ -621,6 +671,34 @@ def vqa_test_wrapup(outs, model_name):
     torch.distributed.barrier()
     os.remove(f"vqa_submit_{rank}.json")
 
+def gqa_test_wrapup(outs, model_name):
+    rank = torch.distributed.get_rank()
+    qids, preds = list(), list()
+    for out in outs:
+        qids += out["qids"]
+        preds += out["preds"]
+
+    rets = list()
+    for qid, pred in zip(qids, preds):
+        rets.append({"question_id": qid, "answer": pred})
+    with open(f"gqa_submit_{rank}.json", "w") as fp:
+        json.dump(rets, fp, indent=4)
+
+    torch.distributed.barrier()
+
+    if rank == 0:
+        jsons = list()
+        paths = list(glob.glob("gqa_submit_*.json"))
+        for path in paths:
+            with open(path, "r") as fp:
+                jsons += json.load(fp)
+        os.makedirs("result", exist_ok=True)
+        with open(f"result/gqa_submit_{model_name}.json", "w") as fp:
+            json.dump(jsons, fp, indent=4)
+
+    torch.distributed.barrier()
+    os.remove(f"gqa_submit_{rank}.json")
+
 
 def arc_test_wrapup(outs, caplen, model_name):
     rank = torch.distributed.get_rank()
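
`compute_gqa` mirrors `compute_vqa`: the target matrix carries the answer score at each sample's label index, and the element-wise BCE (averaged over every sample-class cell) is rescaled by the class count back to a per-sample magnitude. The same construction with toy sizes, using 4 classes instead of `gqa_label_size = 1878`:

```python
import torch
import torch.nn.functional as F

num_classes = 4
gqa_logits = torch.randn(2, num_classes)  # stand-in for the classifier head
gqa_labels = [1, 3]      # per-sample answer label ids
gqa_scores = [1.0, 1.0]  # GQA has one gold answer per question, so score 1.0

# Scatter each sample's score into its label column, as the loop above does.
gqa_targets = torch.zeros(len(gqa_logits), num_classes)
for i, _label in enumerate(gqa_labels):
    gqa_targets[i, _label] = torch.tensor(gqa_scores[i])

# BCE averages over all cells; multiplying by num_classes undoes the
# per-class averaging -- the same convention compute_vqa uses.
gqa_loss = F.binary_cross_entropy_with_logits(gqa_logits, gqa_targets) * num_classes
print(gqa_loss)
```
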
diff --git a/vilt/modules/vilt_module.py b/vilt/modules/vilt_module.py
index 0a678f9..3f193d1 100644
--- a/vilt/modules/vilt_module.py
+++ b/vilt/modules/vilt_module.py
@@ -6,6 +6,8 @@
 from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings
 from vilt.modules import heads, objectives, vilt_utils
 
+def contains_nan(tensor):
+    return torch.any(torch.isnan(tensor))
 
 class ViLTransformerSS(pl.LightningModule):
     def __init__(self, config):
@@ -73,6 +75,17 @@ def __init__(self, config):
                 nn.Linear(hs * 2, vs),
             )
             self.vqa_classifier.apply(objectives.init_weights)
+
+        if self.hparams.config["loss_names"]["gqa"] > 0:
+            gqa_label_size = self.hparams.config["gqa_label_size"]
+
+            self.gqa_classifier = nn.Sequential(
+                nn.Linear(hs, hs * 2),
+                nn.LayerNorm(hs * 2),
+                nn.GELU(),
+                nn.Linear(hs * 2, gqa_label_size),
+            )
+            self.gqa_classifier.apply(objectives.init_weights)
 
         if self.hparams.config["loss_names"]["nlvr2"] > 0:
             self.nlvr2_classifier = nn.Sequential(
@@ -120,7 +133,6 @@ def infer(
             imgkey = f"image_{image_token_type_idx - 1}"
         else:
             imgkey = "image"
-
         do_mlm = "_mlm" if mask_text else ""
         text_ids = batch[f"text_ids{do_mlm}"]
         text_labels = batch[f"text_labels{do_mlm}"]
@@ -180,7 +192,6 @@ def infer(
             "text_masks": text_masks,
             "patch_index": patch_index,
         }
-
         return ret
 
     def forward(self, batch):
@@ -204,6 +215,9 @@ def forward(self, batch):
         # Visual Question Answering
         if "vqa" in self.current_tasks:
             ret.update(objectives.compute_vqa(self, batch))
+
+        if "gqa" in self.current_tasks:
+            ret.update(objectives.compute_gqa(self, batch))
 
         # Natural Language for Visual Reasoning 2
         if "nlvr2" in self.current_tasks:
@@ -212,7 +226,6 @@ def forward(self, batch):
         # Image Retrieval and Text Retrieval
         if "irtr" in self.current_tasks:
             ret.update(objectives.compute_irtr(self, batch))
-
         return ret
 
     def training_step(self, batch, batch_idx):
@@ -220,6 +233,15 @@ def training_step(self, batch, batch_idx):
         output = self(batch)
         total_loss = sum([v for k, v in output.items() if "loss" in k])
 
+        questions = batch["text"]
+        logits = output["gqa_logits"]
+        p = torch.argmax(logits, dim=-1)
+        d = self.trainer.datamodule.dm_dicts["gqa"].id2answer
+        a = batch['gqa_label']
+        preds = [d[pred.item()] for pred in p]
+        answers = [d[pred] for pred in a]
+        # for i, j, k in zip(questions, preds, answers):
+        #     print(f'Question: {i}\nPred: {j}, Actual: {k}')
         return total_loss
 
     def training_epoch_end(self, outs):
@@ -239,14 +261,22 @@ def test_step(self, batch, batch_idx):
 
         if self.hparams.config["loss_names"]["vqa"] > 0:
             ret.update(objectives.vqa_test_step(self, batch, output))
+
+        if self.hparams.config["loss_names"]["gqa"] > 0:
+            ret.update(objectives.gqa_test_step(self, batch, output))
+
         return ret
 
     def test_epoch_end(self, outs):
         model_name = self.hparams.config["load_path"].split("/")[-1][:-5]
-
+        print(model_name)
         if self.hparams.config["loss_names"]["vqa"] > 0:
             objectives.vqa_test_wrapup(outs, model_name)
+
+        if self.hparams.config["loss_names"]["gqa"] > 0:
+            objectives.gqa_test_wrapup(outs, model_name)
+
         vilt_utils.epoch_wrapup(self)
 
     def configure_optimizers(self):
diff --git a/vilt/modules/vilt_utils.py b/vilt/modules/vilt_utils.py
index 8ff4804..1f4f70d 100644
--- a/vilt/modules/vilt_utils.py
+++ b/vilt/modules/vilt_utils.py
@@ -8,7 +8,7 @@
 )
 from vilt.modules.dist_utils import all_gather
 from vilt.modules.objectives import compute_irtr_recall
-from vilt.gadgets.my_metrics import Accuracy, VQAScore, Scalar
+from vilt.gadgets.my_metrics import Accuracy, VQAScore, GQAScore, Scalar
 
 
 def set_metrics(pl_module):
@@ -19,6 +19,9 @@ def set_metrics(pl_module):
             if k == "vqa":
                 setattr(pl_module, f"{split}_vqa_score", VQAScore())
                 setattr(pl_module, f"{split}_{k}_loss", Scalar())
+            elif k == "gqa":
+                setattr(pl_module, f"{split}_gqa_score", GQAScore())
+                setattr(pl_module, f"{split}_{k}_loss", Scalar())
             elif k == "nlvr2":
                 if split == "train":
                     setattr(pl_module, f"train_{k}_accuracy", Accuracy())
@@ -83,6 +86,15 @@ def epoch_wrapup(pl_module):
                 getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
             )
             getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
+        elif loss_name == "gqa":
+            value = getattr(pl_module, f"{phase}_{loss_name}_score").compute()
+            pl_module.log(f"{loss_name}/{phase}/score_epoch", value)
+            getattr(pl_module, f"{phase}_{loss_name}_score").reset()
+            pl_module.log(
+                f"{loss_name}/{phase}/loss_epoch",
+                getattr(pl_module, f"{phase}_{loss_name}_loss").compute(),
+            )
+            getattr(pl_module, f"{phase}_{loss_name}_loss").reset()
         elif loss_name == "nlvr2":
             if phase == "train":
                 value = getattr(pl_module, f"train_{loss_name}_accuracy").compute()
diff --git a/vilt/utils/write_gqa.py b/vilt/utils/write_gqa.py
new file mode 100644
index 0000000..c96997c
--- /dev/null
+++ b/vilt/utils/write_gqa.py
@@ -0,0 +1,128 @@
+import json
+import pandas as pd
+import pyarrow as pa
+import os
+
+from tqdm import tqdm
+from glob import glob
+from collections import defaultdict
+
+def get_score():
+    return 1.0
+
+def path2rest(path, split, annotations, image_id_str, answer_label_dict):
+    with open(path, "rb") as fp:
+        binary = fp.read()
+
+    _annot = annotations[split][image_id_str]
+    _annot = list(_annot.items())
+    qids, qas = [a[0] for a in _annot], [a[1] for a in _annot]
+    questions = [qa[0] for qa in qas]
+    if split == "train" or split == "val":
+        answers = [qa[1] for qa in qas]
+        answer_label = (
+            [answer_label_dict.setdefault(answer, len(answer_label_dict) + 1) for answer in answers]
+        )
+    else:
+        answers = []
+        answer_label = []
+
+    answer_scores = [get_score() for _ in answers] if "test" not in split else list()
+
+    return [binary, questions, answers, answer_label, answer_scores, image_id_str, qids, split]
+
+def split_val_dataset(dataset_root, arrow_filename, split_ratio=0.9):
+    table = pa.ipc.RecordBatchFileReader(
+        pa.memory_map(f"{dataset_root}/{arrow_filename}", "r")
+    ).read_all()
+
+    pdtable = table.to_pandas()
+
+    split_index = int(len(pdtable) * split_ratio)
+    df1 = pdtable[:split_index]
+    df2 = pdtable[split_index:]
+
+    df1 = pa.Table.from_pandas(df1)
+    df2 = pa.Table.from_pandas(df2)
+
+    with pa.OSFile(f"{dataset_root}/gqa_trainable_val.arrow", "wb") as sink:
+        with pa.RecordBatchFileWriter(sink, df1.schema) as writer:
+            writer.write_table(df1)
+
+    with pa.OSFile(f"{dataset_root}/gqa_rest_val.arrow", "wb") as sink:
+        with pa.RecordBatchFileWriter(sink, df2.schema) as writer:
+            writer.write_table(df2)
+
+
+def make_arrow(root, dataset_root):
+    # Read question files
+    answer_label_dict = dict()
+    answer_label_counter = 1
+    question_files = {
+        "train": [f"{root}/questions1.2/train_balanced_questions.json"],
+        "val": [f"{root}/questions1.2/val_balanced_questions.json"],
+        "test": [f"{root}/questions1.2/test_balanced_questions.json"],
+        "testdev": [f"{root}/questions1.2/testdev_all_questions.json"]
+    }
+
+    annotations = dict()
+
+    # for split in ["train", "val"]:
+    for split in ["test", "testdev"]:
+        _annot = defaultdict(dict)
+        for question_file in question_files[split]:
+            with open(question_file, "r") as fp:
+                questions = json.load(fp)
+            for q_id, q in tqdm(questions.items()):
+                if split == "test":
+                    _annot[q["imageId"]][q_id] = [q["question"]]
+                else:
+                    _annot[q["imageId"]][q_id] = [q["question"], q["answer"]]
+
+        annotations[split] = _annot
+
+    # for split in ["train", "val"]:
+    for split in ["test", "testdev"]:
+        paths = list(glob(f"{root}/images/*.jpg"))
+        annot_paths = []
+        for path in paths:
+            image_id_str = path.split("/")[-1].split(".")[0]
+            if image_id_str in annotations[split]:
+                annot_paths.append(path)
+
+        print(f'{split}: {len(paths)}, {len(annot_paths)}, {len(annotations[split])}')
+
+        bs = []
+        for path in tqdm(annot_paths):
+            image_id_str = path.split("/")[-1].split(".")[0]
+            bs.append(path2rest(path, split, annotations, image_id_str, answer_label_dict))
+
+        dataframe = pd.DataFrame(
+            bs,
+            columns=[
+                "image",
+                "questions",
+                "answers",
+                "answer_label",  # Add answer_label here
+                "answer_scores",
+                "image_id",
+                "question_id",
+                "split",
+            ],
+        )
+        print(dataframe['questions'].iloc[:5])
+        print(dataframe['answers'].iloc[:5])
+        print(dataframe['answer_label'].iloc[:5])
+        print(dataframe['answer_scores'].iloc[:5])
+        print(dataframe['image_id'].iloc[:5])
+        print(dataframe['question_id'].iloc[:5])
+        print(answer_label_dict)
+
+        table = pa.Table.from_pandas(dataframe)
+
+        os.makedirs(dataset_root, exist_ok=True)
+        with pa.OSFile(f"{dataset_root}/gqa_{split}.arrow", "wb") as sink:
+            with pa.RecordBatchFileWriter(sink, table.schema) as writer:
+                writer.write_table(table)
+
+    split_val_dataset(dataset_root, "gqa_val.arrow")
\ No newline at end of file
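
As committed, `make_arrow` iterates only the `test`/`testdev` splits (the `train`/`val` loop headers are commented out), while `split_val_dataset` at the end expects a `gqa_val.arrow` produced by an earlier train/val pass; generating a full set of arrows therefore appears to take two runs with the loop lists swapped. A hypothetical driver for the writer, with placeholder paths:

```python
# Hypothetical one-off driver; paths are placeholders. Expects GQA question
# JSONs under <root>/questions1.2 and images under <root>/images, and writes
# gqa_<split>.arrow files into <dataset_root>.
from vilt.utils.write_gqa import make_arrow

make_arrow(root="/path/to/gqa", dataset_root="/path/to/arrows")
```
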
From 140d309778c182a3b812621827cddbed4946e7ae Mon Sep 17 00:00:00 2001
From: Shashoo <57130111+Shashwath-kumar@users.noreply.github.com>
Date: Mon, 1 May 2023 20:29:11 -0400
Subject: [PATCH 2/3] Updated readme with ViLT model finetuned on GQA

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 5de4f24..c94c75a 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,7 @@ We provide five pretrained weights
 3. ViLT-B/32 200k finetuned on NLVR2 [link](https://github.com/dandelin/ViLT/releases/download/200k/vilt_nlvr2.ckpt)
 4. ViLT-B/32 200k finetuned on COCO IR/TR [link](https://github.com/dandelin/ViLT/releases/download/200k/vilt_irtr_coco.ckpt)
 5. ViLT-B/32 200k finetuned on F30K IR/TR [link](https://github.com/dandelin/ViLT/releases/download/200k/vilt_irtr_f30k.ckpt)
+6. ViLT-B/32 200k finetuned on GQA [link](https://github.com/keshavshivkumar/ViLT/releases/download/vilt_gqa/vilt_gqa.ckpt)
 
 ## Out-of-the-box MLM + Visualization Demo

From 8c57142455575056b87ccc465426c9f112526ab1 Mon Sep 17 00:00:00 2001
From: GOUTHAM SWAMINATHAN <52864198+GoldenHorde42@users.noreply.github.com>
Date: Tue, 2 May 2023 18:17:05 -0400
Subject: [PATCH 3/3] Update EVAL.md

Added GQA evaluation instructions.
---
 EVAL.md | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/EVAL.md b/EVAL.md
index bef55d1..2895705 100644
--- a/EVAL.md
+++ b/EVAL.md
@@ -1,6 +1,15 @@
 # Evaluation
 The results will vary a bit since we do a batched-inference, which yields padded image batch that would be inconsistently embedded while performing linear image patch projection.
 
+## Evaluate GQA
+```bash
+python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_gqa_randaug test_only=True precision=32 load_path="<YOUR_WEIGHT_ROOT>/vilt_gqa.ckpt"
+
+ex)
+python run.py with data_root=/data2/dsets/dataset num_gpus=8 num_nodes=1 per_gpu_batchsize=64 task_finetune_gqa_randaug test_only=True precision=32 load_path="weights/vilt_gqa.ckpt"
+
+output > This script will generate `result/gqa_submit_last.json`
+```
 ## Evaluate VQAv2
 ```bash
 python run.py with data_root=<ARROW_ROOT> num_gpus=<NUM_GPUS> num_nodes=<NUM_NODES> per_gpu_batchsize=<BS_FITS_YOUR_GPU> task_finetune_vqa_randaug test_only=True precision=32 load_path="<YOUR_WEIGHT_ROOT>/vilt_vqa.ckpt"
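
For reference, `gqa_test_wrapup` in patch 1 merges the per-rank files into `result/gqa_submit_<model_name>.json`, a JSON list of `{"question_id", "answer"}` records. A small sketch for inspecting that output after the eval run above (assumes the file exists from a prior run):

```python
import json

with open("result/gqa_submit_last.json", "r") as fp:
    preds = json.load(fp)

print(len(preds), preds[0])  # e.g. {"question_id": "...", "answer": "..."}

# Official GQA scoring tools expect a {question_id: answer} mapping:
by_qid = {p["question_id"]: p["answer"] for p in preds}
```
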