forked from facebookresearch/mmf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] Add VinVL dataset and docs (facebookresearch#1162)
Summary: Pull Request resolved: facebookresearch#1162 Add the VinVL dataset and builder enabling pretraining and finetuning over arbitrary datasets in MMF. Added docstrings and updated VinVL project docs explaining usage. Test Plan: Imported from OSS **Static Docs Preview: mmf** |[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D32773454/V9/mmf/)| |**Modified Pages**| |[docs/projects/vinvl](https://our.intern.facebook.com/intern/staticdocs/eph/D32773454/V9/mmf/docs/projects/vinvl/)||[docs/projects/vinvl](https://our.intern.facebook.com/intern/staticdocs/eph/D32773454/V8/mmf/docs/projects/vinvl/)||[docs/projects/vinvl](https://our.intern.facebook.com/intern/staticdocs/eph/D32773454/V7/mmf/docs/projects/vinvl/)||[docs/projects/vinvl](https://our.intern.facebook.com/intern/staticdocs/eph/D32773454/V6/mmf/docs/projects/vinvl/)||[docs/projects/vinvl](https://our.intern.facebook.com/intern/staticdocs/eph/D32773454/V5/mmf/docs/projects/vinvl/)||[docs/projects/vinvl](https://our.intern.facebook.com/intern/staticdocs/eph/D32773454/V4/mmf/docs/projects/vinvl/)| Reviewed By: apsdehal Differential Revision: D32773454 Pulled By: Ryan-Qiyu-Jiang fbshipit-source-id: 2cc63a4e95445d6b92c849603ae9d3693675164c
- Loading branch information
1 parent
7634dc7
commit 6e1f7b4
Showing
6 changed files
with
263 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
includes: | ||
- ../vqa2/defaults.yaml | ||
|
||
dataset_config: | ||
vinvl: | ||
# Name used to look up the base dataset's builder in the registry.
base_dataset_name: vqa2 | ||
# JSON file with the idx_to_label / idx_to_attribute maps.
# NOTE(review): this is a user-specific absolute path; override it for your environment.
label_map: /private/home/ryanjiang/winoground/pretrained_models/VG-SGG-dicts-vgoi6-clipped.json | ||
# Config of the base dataset, taken from dataset_config.vqa2 by interpolation.
base_dataset: ${dataset_config.vqa2} | ||
processors: | ||
text_processor: | ||
type: vinvl_text_tokenizer | ||
params: | ||
mask_probability: 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# Public names exported by this package.
__all__ = ["VinVLBuilder", "VinVLDataset"] | ||
|
||
from .builder import VinVLBuilder | ||
from .dataset import VinVLDataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
from mmf.common.registry import registry | ||
from mmf.datasets.builders.vinvl.dataset import VinVLDataset | ||
from mmf.datasets.mmf_dataset_builder import MMFDatasetBuilder | ||
from omegaconf import open_dict | ||
|
||
|
||
@registry.register_builder("vinvl")
class VinVLBuilder(MMFDatasetBuilder):
    """Builder for ``VinVLDataset``, registered under the name ``"vinvl"``.

    It builds a base dataset through that dataset's own builder and hands the
    resulting instance to a ``VinVLDataset``.
    """

    def __init__(
        self, dataset_name="vinvl", dataset_class=VinVLDataset, *args, **kwargs
    ):
        # Extra positional/keyword arguments are accepted but not forwarded.
        super().__init__(dataset_name, dataset_class, dataset_type="train_val")
        # The dataset class is pinned to VinVLDataset regardless of the
        # ``dataset_class`` argument.
        self.dataset_class = VinVLDataset

    @classmethod
    def config_path(cls):
        """Return the path of the default VinVL dataset config."""
        return "configs/datasets/vinvl/defaults.yaml"

    def load(self, config, dataset_type, *args, **kwargs):
        """Build the base dataset and wrap it in a ``VinVLDataset``.

        The VinVL dataset augments an existing dataset within MMF. VinVL
        requires unique inputs for finetuning and pretraining that general
        datasets do not provide. To enable this on arbitrary datasets, the
        VinVL dataset contains a base dataset and returns an augmented
        version of the samples from that base dataset. For more details,
        read the VinVL dataset docstring.

        The VinVL dataset must be constructed with an instance of a base
        dataset, which the client configures in the VinVL config yaml. This
        method therefore instantiates two datasets and passes the base
        dataset to the VinVL dataset instance.

        The VinVL config is expected to have the following structure,
        ```yaml
        dataset_config:
            vinvl:
                base_dataset_name: vqa2
                label_map: <path to label map>
                base_dataset: ${dataset_config.vqa2}
                processors:
                    text_processor:
                        type: vinvl_text_tokenizer
                        params:
                            ...
        ```
        where ``base_dataset`` is the yaml config for the base dataset (vqa2
        in this example) and ``base_dataset_name`` is ``vqa2``.

        Args:
            config: VinVL dataset config. ``base_dataset_name`` defaults to
                ``"vqa2"`` and ``base_dataset`` defaults to ``config`` itself
                when absent.
            dataset_type: Dataset split identifier, forwarded to both loads.

        Returns:
            VinVLDataset: Instance of the VinVLDataset class which contains
            a base dataset instance.
        """
        wrapped_name = config.get("base_dataset_name", "vqa2")
        wrapped_config = config.get("base_dataset", config)

        # Build and load the base dataset through its own registered builder.
        wrapped_builder = registry.get_builder_class(wrapped_name)()
        wrapped_builder.build_dataset(wrapped_config)
        wrapped_dataset = wrapped_builder.load_dataset(wrapped_config, dataset_type)
        if hasattr(wrapped_builder, "update_registry_for_model"):
            wrapped_builder.update_registry_for_model(wrapped_config)

        # The VinVL dataset is loaded from the base dataset's config, with the
        # VinVL text processor and label map written into it.
        # NOTE(review): this edits the base config object in place; if that
        # object is shared, later readers see the VinVL values too - confirm
        # this is intended.
        text_processor_config = config["processors"]["text_processor"]
        with open_dict(wrapped_config):
            wrapped_config["processors"]["text_processor"] = text_processor_config
            wrapped_config["label_map"] = config["label_map"]

        dataset = super().load(wrapped_config, dataset_type, *args, **kwargs)
        dataset.set_base_dataset(wrapped_dataset)
        return dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
import json | ||
import logging | ||
import random | ||
|
||
from mmf.datasets.mmf_dataset import MMFDataset | ||
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__) | ||
|
||
|
||
class VinVLDataset(MMFDataset):
    """The VinVL dataset is a dataset that augments an existing
    dataset within MMF. VinVL requires unique inputs for
    finetuning and pretraining unsupported by general datasets.
    To enable this functionality on arbitrary datasets,
    the VinVL dataset contains a base dataset,
    and returns an augmented version of samples from the
    base dataset.

    For example, the VQA2 dataset may return a sample {image, text}
    The VinVL dataset when asked for a sample, will return
    {image, text', rand_caption, rand_label}
        text' = text + labels
        rand_caption = text from a random example
        rand_label = obj detection labels text for a random example

    Why does this exist?
    VinVL samples contain rand_caption, and rand_label which require
    random choice from the annotations db, and features_db.
    Currently general text_processors do not have access to these
    databases, instead randomness like mismatched_captions in
    masked coco are implemented on the dataset level.
    To support VinVL finetuning and pretraining on general datasets,
    without a major refactor, the VinVL builder and dataset introduce
    a new design pattern to enable processor access to databases.

    Interface and Assumptions:
    The VinVL dataset assumes:
    The sample returned by the base dataset contains a key "text"
    with string text.
    There exists a label_map json file path in the dataset config
    for a json obj containing idx_to_attribute and idx_to_label
    maps. VinVL OD uses VG labels, and this map can be downloaded
    from https://penzhanwu2.blob.core.windows.net/sgg/
    sgg_benchmark/vinvl_model_zoo/VG-SGG-dicts-vgoi6-clipped.json
    The features_db points to features generated from the VinVL
    feature extraction script, consult the VinVL feature
    extraction tutorial for more details.
    """

    def __init__(self, config, dataset_type, *args, **kwargs):
        """Create the dataset and load its label map.

        The dataset name is taken from ``kwargs["name"]``, else from
        ``kwargs["dataset_name"]``, else it is ``"vinvl"``.
        """
        # NOTE(review): the chosen key is also forwarded to the parent inside
        # **kwargs - confirm the parent constructor accepts it.
        name = kwargs.get("name", kwargs.get("dataset_name", "vinvl"))
        super().__init__(name, config, dataset_type, *args, **kwargs)
        # Tags and random negatives are added for every split except "test".
        self.add_tags = self._dataset_type != "test"
        # Filled in later through set_base_dataset(); starting at None lets
        # load_item() report a clear error if that step was skipped.
        self.base_dataset = None
        self.label_map = self.load_label_map(config.get("label_map"))

    def set_base_dataset(self, base_dataset):
        """Attach the base dataset whose samples this dataset augments."""
        self.base_dataset = base_dataset

    def init_processors(self):
        """Initialise processors by delegating to the parent class."""
        super().init_processors()

    def __len__(self):
        """Return the number of entries in the annotation db."""
        return len(self.annotation_db)

    def __getitem__(self, idx):
        """Return the augmented sample at ``idx`` (see ``load_item``)."""
        return self.load_item(idx)

    def load_item(self, idx):
        """Load the base sample at ``idx`` and add the processed text fields.

        Raises:
            RuntimeError: If no base dataset has been attached yet.
        """
        if self.base_dataset is None:
            raise RuntimeError(
                "VinVLDataset has no base dataset; call set_base_dataset() "
                "before loading items."
            )

        base_sample = self.base_dataset.load_item(idx)
        # assumes sample contains key "text" that is the string text
        # when using on vqa2 which returns tokens under key "text"
        # change the vqa2 dataset class to return "text"
        text_processor_argument = {"text": base_sample["text"]}
        # NOTE(review): the nesting of the random-sample steps under this
        # condition was reconstructed from flattened source - confirm.
        if self.add_tags:
            text_processor_argument["text_b"] = self.get_label_str(base_sample)

            # Caption taken from a randomly chosen example.
            random_caption_idx = random.randint(0, len(self.annotation_db) - 1)
            random_caption_sample = self.base_dataset.load_item(random_caption_idx)
            text_processor_argument["random_captions"] = [
                random_caption_sample["text"]
            ]

            # Label string taken from a second, independently chosen example.
            random_labels_idx = random.randint(0, len(self.annotation_db) - 1)
            random_labels_sample = self.base_dataset.load_item(random_labels_idx)
            text_processor_argument["random_labels"] = [
                self.get_label_str(random_labels_sample)
            ]

        processed_caption = self.text_processor(text_processor_argument)
        base_sample.update(processed_caption)
        return base_sample

    def load_label_map(self, map_path):
        """Read and return the JSON label map stored at ``map_path``."""
        # JSON files are UTF-8; do not depend on the platform default encoding.
        with open(map_path, encoding="utf-8") as f:
            return json.load(f)

    def get_label_str(self, sample):
        """Return the sample's label names and attribute names as one string.

        Ids are read from ``sample["image_info_0"]`` under ``"labels"`` and
        ``"attr_labels"`` and looked up, as strings, in the ``idx_to_label``
        and ``idx_to_attribute`` maps. Ids missing from a map contribute an
        empty string. The two space-joined groups are separated by a space.
        """
        image_info = sample["image_info_0"]
        idx_to_label = self.label_map.get("idx_to_label", {})
        idx_to_attribute = self.label_map.get("idx_to_attribute", {})

        label_str = " ".join(
            idx_to_label.get(str(label_id), "")
            for label_id in image_info.get("labels", [])
        )
        attr_str = " ".join(
            idx_to_attribute.get(str(attr_id), "")
            for attr_id in image_info.get("attr_labels", [])
        )
        return label_str + " " + attr_str
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters