Add VQA evaluation script (#198)
Summary:
Pull Request resolved: #198

Add a script to perform evaluation on the GQA dataset.

Test Plan:
```
CUBLAS_WORKSPACE_CONFIG=:4096:8 python -m torch.distributed.launch --nproc_per_node=2 --use_env vqa_eval.py \
    --resume /data/home/ebs/data/mdetr/gqa_resnet101_checkpoint.pth --ema --eval \
    --dataset_config /data/home/ebs/torchmultimodal/examples/mdetr/vqa.json
```

Results:
```
{'answer_attr_accuracy': 0.814785373608903,
 'answer_attr_loss': 0.593756134082187,
 'answer_cat_accuracy': 0.9180445151033386,
 'answer_cat_loss': 0.4044019418754108,
 'answer_global_accuracy': 0.9894276629570747,
 'answer_global_loss': 0.034145722320555585,
 'answer_obj_accuracy': 0.9866454689984102,
 'answer_obj_loss': 0.03846034316059049,
 'answer_rel_accuracy': 0.6911764705882353,
 'answer_rel_loss': 1.4625307823973612,
 'answer_total_accuracy': 0.6164944356120827,
 'answer_type_accuracy': 0.9996820349761526,
 'answer_type_loss': 0.0005808583157603491,
 'giou_loss': 0.0,
 'l1_loss': 0.0,
 'soft_token_loss': 2.1158602026763518,
 'total_loss': 4.649735950760318}
```

Reviewed By: RdoubleA

Differential Revision: D38160123

Pulled By: ebsmothers

fbshipit-source-id: 12b7f139ea36d4e117822bfea1a9e8401e020a4c
ebsmothers authored and facebook-github-bot committed Jul 28, 2022
1 parent 2b330aa commit 479b68f
Showing 4 changed files with 426 additions and 1 deletion.
93 changes: 92 additions & 1 deletion examples/mdetr/data/datamodule.py
@@ -8,7 +8,7 @@
from typing import Callable, Optional

import torch
-from examples.mdetr.data.dataset import build_flickr, collate_fn
+from examples.mdetr.data.dataset import build_flickr, build_gqa, collate_fn
from examples.mdetr.data.transforms import MDETRTransform
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, DistributedSampler
@@ -45,3 +45,94 @@ def val_dataloader(self):
collate_fn=partial(collate_fn, self.tokenizer),
)
return data_loader_val


class GQADataModule(LightningDataModule):
def __init__(self, dataset_config, tokenizer: Optional[Callable] = None):
super().__init__()
self.dataset_config = dataset_config
self.distributed = dataset_config.distributed
self.batch_size = dataset_config.batch_size
self.epoch_chunks = dataset_config.epoch_chunks
self.tokenizer = tokenizer

def setup(self, stage: Optional[str] = None):
if self.tokenizer is None:
self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
self.train_transform = MDETRTransform(self.tokenizer, is_train=True)
self.val_transform = MDETRTransform(self.tokenizer, is_train=False)

if stage == "train":
self.train = build_gqa(
stage, self.tokenizer, self.train_transform, self.dataset_config
)
if stage == "val":
self.val = build_gqa(
stage, self.tokenizer, self.val_transform, self.dataset_config
)

def train_dataloader(self):
        # To handle very big datasets, we chunk the training set into smaller parts.
if self.epoch_chunks > 0:
print(
f"Splitting the training set into {self.epoch_chunks} chunks of size approximately "
f" {len(self.train) // self.epoch_chunks}"
)
chunks = torch.chunk(torch.arange(len(self.train)), self.epoch_chunks)
datasets = [
torch.utils.data.Subset(self.train, chunk.tolist()) for chunk in chunks
]
if self.distributed:
self.samplers_train = [
DistributedSampler(ds, shuffle=True) for ds in datasets
]
else:
self.samplers_train = [
torch.utils.data.RandomSampler(ds) for ds in datasets
]

batch_samplers_train = [
torch.utils.data.BatchSampler(
sampler_train, self.batch_size, drop_last=True
)
for sampler_train in self.samplers_train
]
assert len(batch_samplers_train) == len(datasets)
train_dataloaders = [
DataLoader(
ds,
batch_sampler=batch_sampler_train,
collate_fn=partial(collate_fn, self.tokenizer),
)
for ds, batch_sampler_train in zip(datasets, batch_samplers_train)
]
return train_dataloaders
else:
if self.distributed:
self.sampler_train = DistributedSampler(self.train, shuffle=True)
else:
self.sampler_train = torch.utils.data.RandomSampler(self.train)
batch_sampler_train = torch.utils.data.BatchSampler(
self.sampler_train, self.batch_size, drop_last=True
)
train_dataloader = DataLoader(
self.train,
batch_sampler=batch_sampler_train,
collate_fn=partial(collate_fn, self.tokenizer),
)
return train_dataloader

def val_dataloader(self):
if self.distributed:
sampler = DistributedSampler(self.val, shuffle=False)
else:
sampler = torch.utils.data.SequentialSampler(self.val)

data_loader_val = DataLoader(
self.val,
batch_size=self.batch_size,
sampler=sampler,
drop_last=False,
collate_fn=partial(collate_fn, self.tokenizer),
)
return data_loader_val
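
For reference, a minimal sketch (not part of this commit) of exercising `GQADataModule` outside the evaluation script. The config attributes mirror the ones read above; every path is a placeholder, and the sketch assumes the `roberta-base` tokenizer can be downloaded:

```
# Sketch only: SimpleNamespace stands in for the parsed vqa.json config;
# all paths below are placeholders and must point at real data.
from types import SimpleNamespace

from examples.mdetr.data.datamodule import GQADataModule

dataset_config = SimpleNamespace(
    distributed=False,  # single-process run, so plain samplers are used
    batch_size=4,
    epoch_chunks=-1,  # disable training-set chunking
    vg_img_path="/path/to/vg/images",  # placeholder
    gqa_ann_path="/path/to/gqa/annotations",  # placeholder
    gqa_split_type="balanced",  # placeholder split name
)

dm = GQADataModule(dataset_config)
dm.setup(stage="val")
batch = next(iter(dm.val_dataloader()))
```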
161 changes: 161 additions & 0 deletions examples/mdetr/data/dataset.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from pathlib import Path

import torch
@@ -60,6 +61,110 @@ def __getitem__(self, idx):
return img, target


GQA_TYPE_TO_ID = {"obj": 0, "attr": 1, "rel": 2, "global": 3, "cat": 4}


class GQADataset(CocoDetection):
def __init__(
self, img_folder, ann_file, transforms, return_tokens, tokenizer, ann_folder
):
super(GQADataset, self).__init__(img_folder, ann_file)
self._transforms = transforms
self.prepare = ConvertCocoPolysToMask(return_tokens, tokenizer=tokenizer)
with open(ann_folder / "gqa_answer2id.json", "r") as f:
self.answer2id = json.load(f)
with open(ann_folder / "gqa_answer2id_by_type.json", "r") as f:
self.answer2id_by_type = json.load(f)
self.type_to_id = GQA_TYPE_TO_ID

def __getitem__(self, idx):
img, target = super(GQADataset, self).__getitem__(idx)
image_id = self.ids[idx]
coco_img = self.coco.loadImgs(image_id)[0]
caption = coco_img["caption"]
dataset_name = coco_img["dataset_name"]
question_id = coco_img["questionId"]
target = {"image_id": image_id, "annotations": target, "caption": caption}
img, target = self.prepare(img, target)
if self._transforms is not None:
img, target = self._transforms(img, target)
target["dataset_name"] = dataset_name
target["questionId"] = question_id

if coco_img["answer"] not in self.answer2id:
answer = "unknown"
else:
answer = coco_img["answer"]

target["answer"] = torch.as_tensor(self.answer2id[answer], dtype=torch.long)
target["answer_type"] = torch.as_tensor(
self.type_to_id[coco_img["question_type"]], dtype=torch.long
)
target["answer_type_mask"] = {
f"answer_{k}": torch.BoolTensor([True])
if coco_img["question_type"] == k
else torch.BoolTensor([False])
for k in self.type_to_id.keys()
}
target["answer_type_mask"]["answer_type"] = torch.BoolTensor([True])

if coco_img["answer"] not in self.answer2id_by_type["answer_attr"]:
answer = "unknown"
else:
answer = coco_img["answer"]
target["answer_attr"] = torch.as_tensor(
self.answer2id_by_type["answer_attr"][answer]
if coco_img["question_type"] == "attr"
else -100,
dtype=torch.long,
)

if coco_img["answer"] not in self.answer2id_by_type["answer_global"]:
answer = "unknown"
else:
answer = coco_img["answer"]
target["answer_global"] = torch.as_tensor(
self.answer2id_by_type["answer_global"][answer]
if coco_img["question_type"] == "global"
else -100,
dtype=torch.long,
)

if coco_img["answer"] not in self.answer2id_by_type["answer_rel"]:
answer = "unknown"
else:
answer = coco_img["answer"]
target["answer_rel"] = torch.as_tensor(
self.answer2id_by_type["answer_rel"][answer]
if coco_img["question_type"] == "rel"
else -100,
dtype=torch.long,
)

if coco_img["answer"] not in self.answer2id_by_type["answer_cat"]:
answer = "unknown"
else:
answer = coco_img["answer"]
target["answer_cat"] = torch.as_tensor(
self.answer2id_by_type["answer_cat"][answer]
if coco_img["question_type"] == "cat"
else -100,
dtype=torch.long,
)

if coco_img["answer"] not in self.answer2id_by_type["answer_obj"]:
answer = "unknown"
else:
answer = coco_img["answer"]
target["answer_obj"] = torch.as_tensor(
self.answer2id_by_type["answer_obj"][answer]
if coco_img["question_type"] == "obj"
else -100,
dtype=torch.long,
)
return img, target
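
An editorial aside (not part of the diff): the -100 fill for non-matching question types lines up with the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so each per-type answer head only receives gradient from questions of its own type. A standalone illustration with made-up values:

```
import torch
import torch.nn as nn

# CrossEntropyLoss skips targets equal to ignore_index (default -100),
# so only the first question below contributes to this head's loss.
loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(3, 10)  # e.g. a 10-way "attr" head over 3 questions
targets = torch.tensor([4, -100, -100])  # one "attr" question, two of other types
print(loss_fn(logits, targets))  # mean over the single non-ignored target
```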


def collate_fn(tokenizer, batch):
batch = list(zip(*batch))
final_batch = {}
@@ -151,3 +256,59 @@ def build_flickr(image_set, tokenizer, transform, args):
is_train=is_train,
)
return dataset


def build_gqa(image_set, tokenizer, transform, args):
img_dir = Path(args.vg_img_path)
assert img_dir.exists(), f"provided VG img path {img_dir} does not exist"

assert args.gqa_split_type is not None

if image_set == "train":
datasets = []
for imset in ["train", "val"]:
ann_file = (
Path(args.gqa_ann_path)
/ f"finetune_gqa_{imset}_{args.gqa_split_type}.json"
)

datasets.append(
GQADataset(
img_dir,
ann_file,
transforms=transform,
return_tokens=True,
tokenizer=tokenizer,
ann_folder=Path(args.gqa_ann_path),
)
)

return torch.utils.data.ConcatDataset(datasets)
elif image_set == "val":
ann_file = Path(args.gqa_ann_path) / "finetune_gqa_testdev_balanced.json"

return GQADataset(
img_dir,
ann_file,
transforms=transform,
return_tokens=True,
tokenizer=tokenizer,
ann_folder=Path(args.gqa_ann_path),
)
elif image_set in ["test", "challenge", "testdev", "submission"]:
ann_file = (
Path(args.gqa_ann_path)
/ f"finetune_gqa_{image_set}_{args.gqa_split_type}.json"
)

return GQADataset(
img_dir,
ann_file,
transforms=transform,
return_tokens=True,
tokenizer=tokenizer,
ann_folder=Path(args.gqa_ann_path),
)

else:
raise ValueError(f"Unknown image set {image_set}")
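
Complementing the datamodule sketch above, `build_gqa` can also be called directly. A hedged sketch, again with placeholder paths, evaluating on the validation split (which resolves to the balanced testdev annotation file):

```
# Sketch only: assumes GQA annotations in MDETR's COCO-style format under
# gqa_ann_path; all paths are placeholders.
from types import SimpleNamespace

from examples.mdetr.data.dataset import build_gqa
from examples.mdetr.data.transforms import MDETRTransform
from transformers import RobertaTokenizerFast

args = SimpleNamespace(
    vg_img_path="/path/to/vg/images",
    gqa_ann_path="/path/to/gqa/annotations",
    gqa_split_type="balanced",
)
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
transform = MDETRTransform(tokenizer, is_train=False)
dataset = build_gqa("val", tokenizer, transform, args)  # testdev balanced split
img, target = dataset[0]
```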
7 changes: 7 additions & 0 deletions examples/mdetr/vqa.json
@@ -0,0 +1,7 @@
{
"combine_datasets": ["gqa"],
"combine_datasets_val": ["gqa"],
"vg_img_path": "",
"gqa_ann_path": "",
"gqa_split_type": ""
}
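
The empty strings in vqa.json are deliberate placeholders: `vg_img_path`, `gqa_ann_path`, and `gqa_split_type` must be filled in locally before the test-plan command will run. One plausible way a driver could consume this file (a sketch only; the actual parsing lives in vqa_eval.py, whose diff is not rendered above):

```
# Sketch only: turns vqa.json into the attribute-style config object the
# data code expects; not necessarily how vqa_eval.py actually parses it.
import json
from types import SimpleNamespace

with open("examples/mdetr/vqa.json") as f:
    dataset_config = SimpleNamespace(**json.load(f))

assert dataset_config.vg_img_path, "fill in vg_img_path before running"
```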
166 changes: 166 additions & 0 deletions examples/mdetr/vqa_eval.py (diff not rendered on this page)
