-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #87 from nisyad-ms/nisyad/add_kvp_adapter_for_vqa_…
…datasets Nisyad/add kvp adapter for vqa datasets
- Loading branch information
Showing
8 changed files
with
185 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import unittest | ||
|
||
from tests.test_fixtures import VQATestFixtures | ||
from vision_datasets.common.constants import DatasetTypes | ||
from vision_datasets.key_value_pair.manifest import KeyValuePairLabelManifest | ||
from vision_datasets.visual_question_answering import VQAAsKeyValuePairDataset | ||
|
||
|
||
class TestVQAAsKeyValuePairDataset(unittest.TestCase): | ||
def test_vqa_to_kvp(self): | ||
sample_vqa_dataset, tempdir = VQATestFixtures.create_a_vqa_dataset() | ||
with tempdir: | ||
kvp_dataset = VQAAsKeyValuePairDataset(sample_vqa_dataset) | ||
|
||
self.assertIsInstance(kvp_dataset, VQAAsKeyValuePairDataset) | ||
self.assertEqual(kvp_dataset.dataset_info.type, DatasetTypes.KEY_VALUE_PAIR) | ||
self.assertIn("name", kvp_dataset.dataset_info.schema) | ||
self.assertIn("description", kvp_dataset.dataset_info.schema) | ||
self.assertIn("fieldSchema", kvp_dataset.dataset_info.schema) | ||
|
||
self.assertEqual(kvp_dataset.dataset_info.schema["fieldSchema"], | ||
{'answer': {'type': 'string', 'description': 'Answer to the question.'}, | ||
'rationale': {'type': 'string', 'description': 'Rationale for the answer.'}}) | ||
|
||
_, target, _ = kvp_dataset[0] | ||
self.assertIsInstance(target, KeyValuePairLabelManifest) | ||
self.assertEqual(target.label_data, | ||
{'fields': {'answer': {'value': 'answer 1'}}, 'text': {'question': 'question 1'}}) | ||
|
||
self.assertEqual(len(kvp_dataset), 3) | ||
self.assertEqual(len(kvp_dataset.dataset_manifest.images), 2) | ||
|
||
# Last image has 2 questions associated with it | ||
self.assertEqual(kvp_dataset[-2][0][0].size, kvp_dataset[-1][0][0].size) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .coco_manifest_adaptor import VisualQuestionAnswerinCocoManifestAdaptor | ||
from .manifest import VisualQuestionAnsweringLabelManifest | ||
from .operations import VisualQuestionAnsweringCocoDictGenerator | ||
from .vqa_as_kvp_dataset import VQAAsKeyValuePairDataset | ||
|
||
__all__ = ['VisualQuestionAnswerinCocoManifestAdaptor', 'VisualQuestionAnsweringLabelManifest', 'VisualQuestionAnsweringCocoDictGenerator'] | ||
__all__ = ['VisualQuestionAnswerinCocoManifestAdaptor', 'VisualQuestionAnsweringLabelManifest', 'VisualQuestionAnsweringCocoDictGenerator', 'VQAAsKeyValuePairDataset'] |
99 changes: 99 additions & 0 deletions
99
vision_datasets/visual_question_answering/vqa_as_kvp_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from copy import deepcopy | ||
import logging | ||
from typing import Any, Dict | ||
|
||
from vision_datasets.common import DatasetTypes, KeyValuePairDatasetInfo, VisionDataset | ||
from vision_datasets.key_value_pair import ( | ||
KeyValuePairDatasetManifest, | ||
KeyValuePairLabelManifest, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class VQAAsKeyValuePairDataset(VisionDataset): | ||
"""Dataset class that access Visual Question Answering (VQA) datset as KeyValuePair dataset.""" | ||
|
||
ANSWER_KEY = "answer" | ||
RATIONALE_KEY = "rationale" | ||
QUESTION_KEY = "question" | ||
SCHEMA_BASE = { | ||
"name": "Visual Question Answering", | ||
"description": "Answer questions on given images and provide rationale for the answer.", | ||
"fieldSchema": { | ||
ANSWER_KEY: { | ||
"type": "string", | ||
"description": "Answer to the question.", | ||
}, | ||
RATIONALE_KEY: { | ||
"type": "string", | ||
"description": "Rationale for the answer.", | ||
}, | ||
} | ||
} | ||
|
||
def __init__(self, vqa_dataset: VisionDataset): | ||
""" | ||
Initializes an instance of the VQAAsKeyValuePairDataset class. | ||
Args: | ||
vqa_dataset (VisionDataset): The VQA dataset to convert to key-value pair dataset. | ||
""" | ||
|
||
if vqa_dataset is None or vqa_dataset.dataset_info.type is not DatasetTypes.VISUAL_QUESTION_ANSWERING: | ||
raise ValueError("Input dataset must be a Visual Question Answering dataset.") | ||
|
||
# Generate schema and update dataset info | ||
vqa_dataset = deepcopy(vqa_dataset) | ||
|
||
dataset_info_dict = vqa_dataset.dataset_info.__dict__ | ||
dataset_info_dict["type"] = DatasetTypes.KEY_VALUE_PAIR.name.lower() | ||
self.img_id_to_pos = {x.id: i for i, x in enumerate(vqa_dataset.dataset_manifest.images)} | ||
|
||
schema = self.construct_schema() | ||
# Update dataset_info with schema | ||
dataset_info = KeyValuePairDatasetInfo({**dataset_info_dict, "schema": schema}) | ||
|
||
# Construct KeyValuePairDatasetManifest | ||
annotations = [] | ||
id = 1 | ||
for _, img in enumerate(vqa_dataset.dataset_manifest.images, 1): | ||
label_data = [label.label_data for label in img.labels] | ||
|
||
for label in label_data: | ||
kvp_label_data = self.construct_kvp_label_data(label) | ||
img_ids = [self.img_id_to_pos[img.id]] # 0-based index | ||
kvp_annotation = KeyValuePairLabelManifest(id, img_ids, label_data=kvp_label_data) | ||
id += 1 | ||
|
||
# KVPDatasetManifest expects img.labels to be empty. Labels are instead stored in KVP annotation | ||
img.labels = [] | ||
annotations.append(kvp_annotation) | ||
|
||
dataset_manifest = KeyValuePairDatasetManifest(vqa_dataset.dataset_manifest.images, annotations, schema, additional_info=vqa_dataset.dataset_manifest.additional_info) | ||
super().__init__(dataset_info, dataset_manifest, dataset_resources=vqa_dataset.dataset_resources) | ||
|
||
def construct_schema(self) -> Dict[str, Any]: | ||
return self.SCHEMA_BASE | ||
|
||
def construct_kvp_label_data(self, label: Dict[str, str]) -> Dict[str, Dict[str, str]]: | ||
""" | ||
Convert the VQA dataset label to the desired format for KVP annotation as defined by the SCHEMA_BASE. | ||
E.g. {"fields": | ||
{"answer": {"value": "yes"}}, | ||
"text": {"question": "Is there a dog in the image?"} | ||
} | ||
""" | ||
|
||
if self.QUESTION_KEY not in label: | ||
raise KeyError(f"Question key '{self.QUESTION_KEY}' not found in label.") | ||
if self.ANSWER_KEY not in label: | ||
raise KeyError(f"Answer key '{self.ANSWER_KEY}' not found in label.") | ||
|
||
kvp_label_data = { | ||
KeyValuePairLabelManifest.LABEL_KEY: { | ||
self.ANSWER_KEY: {KeyValuePairLabelManifest.LABEL_VALUE_KEY: label[self.ANSWER_KEY]}, | ||
}, | ||
KeyValuePairLabelManifest.TEXT_INPUT_KEY: {self.QUESTION_KEY: label[self.QUESTION_KEY]}, | ||
} | ||
|
||
return kvp_label_data |