Merge pull request #87 from nisyad-ms/nisyad/add_kvp_adapter_for_vqa_datasets

Nisyad/add kvp adapter for vqa datasets
cy-bai authored Sep 26, 2024
2 parents ee36122 + a8fb3a3 commit 99a7bef
Showing 8 changed files with 185 additions and 3 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -92,7 +92,7 @@ print(imgs)
print(target)
```

### Loading IC/OD Datasets in KeyValuePair (KVP) Format:
### Loading IC/OD/VQA Datasets in KeyValuePair (KVP) Format:
You can convert an existing IC/OD VisionDataset to the generalized KVP format using the following adapter:

```{python}
@@ -109,6 +109,11 @@ from vision_datasets.image_object_detection import DetectionAsKeyValuePairDatase
sample_od_dataset = VisionDataset(dataset_info, dataset_manifest)
kvp_dataset = DetectionAsKeyValuePairDataset(sample_od_dataset)
kvp_dataset_for_multilabel_classification = DetectionAsKeyValuePairDatasetForMultilabelClassification(sample_od_dataset)
# For VQA dataset
from vision_datasets.visual_question_answering import VQAAsKeyValuePairDataset
sample_vqa_dataset = VisionDataset(dataset_info, dataset_manifest)
kvp_dataset = VQAAsKeyValuePairDataset(sample_vqa_dataset)
```
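
Once converted, the dataset can be indexed like any other VisionDataset; a minimal sketch (reusing the `kvp_dataset` built from `sample_vqa_dataset` above) might look like this:

```{python}
# Read one converted sample: a list of PIL images, a KeyValuePair target, and an index
imgs, target, _ = kvp_dataset[0]
print(imgs[0].size)
print(target.label_data)  # e.g. {'fields': {'answer': {'value': ...}}, 'text': {'question': ...}}
```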


2 changes: 1 addition & 1 deletion setup.py
@@ -1,7 +1,7 @@
import setuptools
from os import path

VERSION = '1.0.16'
VERSION = '1.0.17'

# Get the long description from the README file
here = path.abspath(path.dirname(__file__))
34 changes: 34 additions & 0 deletions tests/test_coco_iris_to_kvp_wrapper/test_vqa_as_kvp.py
@@ -0,0 +1,34 @@
import unittest

from tests.test_fixtures import VQATestFixtures
from vision_datasets.common.constants import DatasetTypes
from vision_datasets.key_value_pair.manifest import KeyValuePairLabelManifest
from vision_datasets.visual_question_answering import VQAAsKeyValuePairDataset


class TestVQAAsKeyValuePairDataset(unittest.TestCase):
    def test_vqa_to_kvp(self):
        sample_vqa_dataset, tempdir = VQATestFixtures.create_a_vqa_dataset()
        with tempdir:
            kvp_dataset = VQAAsKeyValuePairDataset(sample_vqa_dataset)

            self.assertIsInstance(kvp_dataset, VQAAsKeyValuePairDataset)
            self.assertEqual(kvp_dataset.dataset_info.type, DatasetTypes.KEY_VALUE_PAIR)
            self.assertIn("name", kvp_dataset.dataset_info.schema)
            self.assertIn("description", kvp_dataset.dataset_info.schema)
            self.assertIn("fieldSchema", kvp_dataset.dataset_info.schema)

            self.assertEqual(kvp_dataset.dataset_info.schema["fieldSchema"],
                             {'answer': {'type': 'string', 'description': 'Answer to the question.'},
                              'rationale': {'type': 'string', 'description': 'Rationale for the answer.'}})

            _, target, _ = kvp_dataset[0]
            self.assertIsInstance(target, KeyValuePairLabelManifest)
            self.assertEqual(target.label_data,
                             {'fields': {'answer': {'value': 'answer 1'}}, 'text': {'question': 'question 1'}})

            self.assertEqual(len(kvp_dataset), 3)
            self.assertEqual(len(kvp_dataset.dataset_manifest.images), 2)

            # Last image has 2 questions associated with it
            self.assertEqual(kvp_dataset[-2][0][0].size, kvp_dataset[-1][0][0].size)
43 changes: 43 additions & 0 deletions tests/test_fixtures.py
@@ -2,6 +2,7 @@
import json
import pathlib
import tempfile

from PIL import Image

from vision_datasets.common import (
@@ -137,3 +138,45 @@ def create_an_ic_manifest(root_dir='', n_images=2, n_categories=3):
coco_path = pathlib.Path(root_dir) / 'coco.json'
coco_path.write_text(json.dumps(coco_dict))
return CocoManifestAdaptorFactory.create(DatasetTypes.IMAGE_CLASSIFICATION_MULTILABEL).create_dataset_manifest(coco_path.name, root_dir)


class VQATestFixtures:
    DATASET_INFO_DICT = {
        "name": "dummy",
        "version": 1,
        "type": "visual_question_answering",
        "root_folder": "dummy",
        "format": "coco",
        "test": {
            "index_path": "train.json",
            "files_for_local_usage": [
                "train.zip"
            ]
        },
    }

    @staticmethod
    def create_a_vqa_dataset(n_images=2):
        dataset_dict = copy.deepcopy(VQATestFixtures.DATASET_INFO_DICT)
        tempdir = tempfile.TemporaryDirectory()
        dataset_dict['root_folder'] = tempdir.name
        for i in range(n_images):
            Image.new('RGB', (min(1000, (i+1) * 100), min(1000, (i+1) * 100))).save(pathlib.Path(tempdir.name) / f'{i + 1}.jpg')

        dataset_info = DatasetInfo(dataset_dict)
        dataset_manifest = VQATestFixtures.create_a_vqa_manifest(tempdir.name, n_images)
        dataset = VisionDataset(dataset_info, dataset_manifest)
        return dataset, tempdir

    @staticmethod
    def create_a_vqa_manifest(root_dir='', n_images=2):
        images = [{'id': i + 1, 'file_name': f'{i + 1}.jpg', 'width': min(1000, (i+1)*100), 'height': min(1000, (i+1)*100)} for i in range(n_images)]
        annotations = [{'id': i + 1, 'image_id': i + 1, 'question': f'question {i+1}', 'answer': f'answer {i+1}'} for i in range(n_images)]

        # Add a second question for the last image
        annotations.append({'id': n_images + 1, 'image_id': n_images, 'question': 'question 3', 'answer': 'answer 3'})

        coco_dict = {'images': images, 'annotations': annotations}
        coco_path = pathlib.Path(root_dir) / 'coco.json'
        coco_path.write_text(json.dumps(coco_dict))
        return CocoManifestAdaptorFactory.create(DatasetTypes.VISUAL_QUESTION_ANSWERING).create_dataset_manifest(coco_path.name, root_dir)
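
For reference, with the default `n_images=2` the fixture above writes a `coco.json` roughly equivalent to the following (a sketch of the generated content, not a file checked into the repo):

```python
# Approximate contents of the coco.json produced by create_a_vqa_manifest(n_images=2)
coco_dict = {
    'images': [
        {'id': 1, 'file_name': '1.jpg', 'width': 100, 'height': 100},
        {'id': 2, 'file_name': '2.jpg', 'width': 200, 'height': 200},
    ],
    'annotations': [
        {'id': 1, 'image_id': 1, 'question': 'question 1', 'answer': 'answer 1'},
        {'id': 2, 'image_id': 2, 'question': 'question 2', 'answer': 'answer 2'},
        # the extra question appended for the last image
        {'id': 3, 'image_id': 2, 'question': 'question 3', 'answer': 'answer 3'},
    ],
}
```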
3 changes: 2 additions & 1 deletion vision_datasets/visual_question_answering/__init__.py
@@ -1,5 +1,6 @@
from .coco_manifest_adaptor import VisualQuestionAnswerinCocoManifestAdaptor
from .manifest import VisualQuestionAnsweringLabelManifest
from .operations import VisualQuestionAnsweringCocoDictGenerator
from .vqa_as_kvp_dataset import VQAAsKeyValuePairDataset

__all__ = ['VisualQuestionAnswerinCocoManifestAdaptor', 'VisualQuestionAnsweringLabelManifest', 'VisualQuestionAnsweringCocoDictGenerator']
__all__ = ['VisualQuestionAnswerinCocoManifestAdaptor', 'VisualQuestionAnsweringLabelManifest', 'VisualQuestionAnsweringCocoDictGenerator', 'VQAAsKeyValuePairDataset']
99 changes: 99 additions & 0 deletions vision_datasets/visual_question_answering/vqa_as_kvp_dataset.py
@@ -0,0 +1,99 @@
from copy import deepcopy
import logging
from typing import Any, Dict

from vision_datasets.common import DatasetTypes, KeyValuePairDatasetInfo, VisionDataset
from vision_datasets.key_value_pair import (
    KeyValuePairDatasetManifest,
    KeyValuePairLabelManifest,
)

logger = logging.getLogger(__name__)


class VQAAsKeyValuePairDataset(VisionDataset):
    """Dataset class that accesses a Visual Question Answering (VQA) dataset as a KeyValuePair dataset."""

    ANSWER_KEY = "answer"
    RATIONALE_KEY = "rationale"
    QUESTION_KEY = "question"
    SCHEMA_BASE = {
        "name": "Visual Question Answering",
        "description": "Answer questions on given images and provide rationale for the answer.",
        "fieldSchema": {
            ANSWER_KEY: {
                "type": "string",
                "description": "Answer to the question.",
            },
            RATIONALE_KEY: {
                "type": "string",
                "description": "Rationale for the answer.",
            },
        }
    }

    def __init__(self, vqa_dataset: VisionDataset):
        """
        Initializes an instance of the VQAAsKeyValuePairDataset class.
        Args:
            vqa_dataset (VisionDataset): The VQA dataset to convert to a KeyValuePair dataset.
        """

        if vqa_dataset is None or vqa_dataset.dataset_info.type is not DatasetTypes.VISUAL_QUESTION_ANSWERING:
            raise ValueError("Input dataset must be a Visual Question Answering dataset.")

        # Generate schema and update dataset info
        vqa_dataset = deepcopy(vqa_dataset)

        dataset_info_dict = vqa_dataset.dataset_info.__dict__
        dataset_info_dict["type"] = DatasetTypes.KEY_VALUE_PAIR.name.lower()
        self.img_id_to_pos = {x.id: i for i, x in enumerate(vqa_dataset.dataset_manifest.images)}

        schema = self.construct_schema()
        # Update dataset_info with schema
        dataset_info = KeyValuePairDatasetInfo({**dataset_info_dict, "schema": schema})

        # Construct KeyValuePairDatasetManifest
        annotations = []
        id = 1
        for img in vqa_dataset.dataset_manifest.images:
            label_data = [label.label_data for label in img.labels]

            for label in label_data:
                kvp_label_data = self.construct_kvp_label_data(label)
                img_ids = [self.img_id_to_pos[img.id]]  # 0-based index
                kvp_annotation = KeyValuePairLabelManifest(id, img_ids, label_data=kvp_label_data)
                id += 1

                # KVPDatasetManifest expects img.labels to be empty. Labels are instead stored in the KVP annotations.
                img.labels = []
                annotations.append(kvp_annotation)

        dataset_manifest = KeyValuePairDatasetManifest(vqa_dataset.dataset_manifest.images, annotations, schema, additional_info=vqa_dataset.dataset_manifest.additional_info)
        super().__init__(dataset_info, dataset_manifest, dataset_resources=vqa_dataset.dataset_resources)

    def construct_schema(self) -> Dict[str, Any]:
        return self.SCHEMA_BASE

    def construct_kvp_label_data(self, label: Dict[str, str]) -> Dict[str, Dict[str, str]]:
        """
        Convert the VQA dataset label to the desired format for KVP annotation as defined by the SCHEMA_BASE.
        E.g. {"fields": {"answer": {"value": "yes"}},
              "text": {"question": "Is there a dog in the image?"}}
        """

        if self.QUESTION_KEY not in label:
            raise KeyError(f"Question key '{self.QUESTION_KEY}' not found in label.")
        if self.ANSWER_KEY not in label:
            raise KeyError(f"Answer key '{self.ANSWER_KEY}' not found in label.")

        kvp_label_data = {
            KeyValuePairLabelManifest.LABEL_KEY: {
                self.ANSWER_KEY: {KeyValuePairLabelManifest.LABEL_VALUE_KEY: label[self.ANSWER_KEY]},
            },
            KeyValuePairLabelManifest.TEXT_INPUT_KEY: {self.QUESTION_KEY: label[self.QUESTION_KEY]},
        }

        return kvp_label_data
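
As a quick illustration of the mapping `construct_kvp_label_data` performs, the sketch below writes out the result for one VQA label by hand; the `fields`/`value`/`text` string keys mirror the constants on `KeyValuePairLabelManifest` as exercised by the test above, and the example label itself is hypothetical.

```python
# Sketch: a VQA label and the KVP label data it maps to (illustrative, not the adapter code itself)
vqa_label = {"question": "Is there a dog in the image?", "answer": "yes"}

kvp_label_data = {
    "fields": {"answer": {"value": vqa_label["answer"]}},  # KeyValuePairLabelManifest.LABEL_KEY / LABEL_VALUE_KEY
    "text": {"question": vqa_label["question"]},           # KeyValuePairLabelManifest.TEXT_INPUT_KEY
}
print(kvp_label_data)
# {'fields': {'answer': {'value': 'yes'}}, 'text': {'question': 'Is there a dog in the image?'}}
```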
