[feat] Add VinVL dataset and docs (facebookresearch#1162)
Ryan-Qiyu-Jiang authored and facebook-github-bot committed Dec 17, 2021
1 parent 7634dc7 commit 6e1f7b4
Showing 6 changed files with 263 additions and 5 deletions.
13 changes: 13 additions & 0 deletions mmf/configs/datasets/vinvl/defaults.yaml
@@ -0,0 +1,13 @@
includes:
- ../vqa2/defaults.yaml

dataset_config:
  vinvl:
    base_dataset_name: vqa2
    label_map: /private/home/ryanjiang/winoground/pretrained_models/VG-SGG-dicts-vgoi6-clipped.json
    base_dataset: ${dataset_config.vqa2}
    processors:
      text_processor:
        type: vinvl_text_tokenizer
        params:
          mask_probability: 0
5 changes: 5 additions & 0 deletions mmf/datasets/builders/vinvl/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
__all__ = ["VinVLBuilder", "VinVLDataset"]

from .builder import VinVLBuilder
from .dataset import VinVLDataset
86 changes: 86 additions & 0 deletions mmf/datasets/builders/vinvl/builder.py
@@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from mmf.common.registry import registry
from mmf.datasets.builders.vinvl.dataset import VinVLDataset
from mmf.datasets.mmf_dataset_builder import MMFDatasetBuilder
from omegaconf import open_dict


@registry.register_builder("vinvl")
class VinVLBuilder(MMFDatasetBuilder):
    def __init__(
        self, dataset_name="vinvl", dataset_class=VinVLDataset, *args, **kwargs
    ):
        super().__init__(dataset_name, dataset_class, dataset_type="train_val")
        self.dataset_class = VinVLDataset

    @classmethod
    def config_path(cls):
        return "configs/datasets/vinvl/defaults.yaml"

    def load(self, config, dataset_type, *args, **kwargs):
        """The VinVL dataset is a dataset that augments an existing
        dataset within MMF. VinVL requires unique inputs for
        finetuning and pretraining unsupported by general datasets.
        To enable this functionality on arbitrary datasets,
        the VinVL dataset contains a base dataset,
        and returns an augmented version of samples from the
        base dataset.
        For more details, read the VinVL dataset docstring.

        The Builder:
            This class is a builder for the VinVL dataset.
            The VinVL dataset must be constructed with an instance of
            a base dataset, which the client configures in the VinVL
            config yaml, so this builder instantiates 2 datasets and
            then passes the base dataset to the VinVL dataset instance.

            The VinVL config is expected to have the following structure,
            ```yaml
            dataset_config:
                vinvl:
                    base_dataset_name: vqa2
                    label_map: <path to label map>
                    base_dataset: ${dataset_config.vqa2}
                    processors:
                        text_processor:
                            type: vinvl_text_tokenizer
                            params:
                                ...
            ```
            Where base_dataset is the yaml config for the base dataset,
            in this example vqa2,
            and base_dataset_name is vqa2.

        Returns:
            VinVLDataset: Instance of the VinVLDataset class which contains
            a base dataset instance.
        """
        base_dataset_name = config.get("base_dataset_name", "vqa2")
        base_dataset_config = config.get("base_dataset", config)
        # instantiate base dataset
        # instantiate base dataset builder
        base_dataset_builder_class = registry.get_builder_class(base_dataset_name)
        base_dataset_builder_instance = base_dataset_builder_class()
        # build base dataset instance
        base_dataset_builder_instance.build_dataset(base_dataset_config)
        base_dataset = base_dataset_builder_instance.load_dataset(
            base_dataset_config, dataset_type
        )
        if hasattr(base_dataset_builder_instance, "update_registry_for_model"):
            base_dataset_builder_instance.update_registry_for_model(base_dataset_config)

        # instantiate vinvl dataset
        vinvl_text_processor = config["processors"]["text_processor"]
        with open_dict(base_dataset_config):
            base_dataset_config["processors"]["text_processor"] = vinvl_text_processor
            base_dataset_config["label_map"] = config["label_map"]

        vinvl_dataset = super().load(base_dataset_config, dataset_type, *args, **kwargs)
        vinvl_dataset.set_base_dataset(base_dataset)
        return vinvl_dataset
110 changes: 110 additions & 0 deletions mmf/datasets/builders/vinvl/dataset.py
@@ -0,0 +1,110 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import json
import logging
import random

from mmf.datasets.mmf_dataset import MMFDataset


logger = logging.getLogger(__name__)


class VinVLDataset(MMFDataset):
    """The VinVL dataset is a dataset that augments an existing
    dataset within MMF. VinVL requires unique inputs for
    finetuning and pretraining unsupported by general datasets.
    To enable this functionality on arbitrary datasets,
    the VinVL dataset contains a base dataset,
    and returns an augmented version of samples from the
    base dataset.

    For example, the VQA2 dataset may return a sample {image, text}
    The VinVL dataset when asked for a sample, will return
    {image, text', rand_caption, rand_label}
        text' = text + labels
        rand_caption = text from a random example
        rand_label = obj detection labels text for a random example

    Why does this exist?
        VinVL samples contain rand_caption, and rand_label which require
        random choice from the annotations db, and features_db.
        Currently general text_processors do not have access to these
        databases, instead randomness like mismatched_captions in
        masked coco are implemented on the dataset level.
        To support VinVL finetuning and pretraining on general datasets,
        without a major refactor, the VinVL builder and dataset introduce
        a new design pattern to enable processor access to databases.

    Interface and Assumptions:
        The VinVL dataset assumes:
        The sample returned by the base dataset contains a key "text"
        with string text.
        There exists a label_map json file path in the dataset config
        for a json obj containing idx_to_attribute and idx_to_label
        maps. VinVL OD uses VG labels, and this map can be downloaded
        from https://penzhanwu2.blob.core.windows.net/sgg/
        sgg_benchmark/vinvl_model_zoo/VG-SGG-dicts-vgoi6-clipped.json
        The features_db points to features generated from the VinVL
        feature extraction script, consult the VinVL feature
        extraction tutorial for more details.
    """

    def __init__(self, config, dataset_type, *args, **kwargs):
        if "name" in kwargs:
            name = kwargs["name"]
        elif "dataset_name" in kwargs:
            name = kwargs["dataset_name"]
        else:
            name = "vinvl"
        super().__init__(name, config, dataset_type, *args, **kwargs)
        self.add_tags = not "test" == self._dataset_type
        self.label_map = self.load_label_map(config.get("label_map"))

    def set_base_dataset(self, base_dataset):
        self.base_dataset = base_dataset

    def init_processors(self):
        super().init_processors()

    def __len__(self):
        return len(self.annotation_db)

    def __getitem__(self, idx):
        return self.load_item(idx)

    def load_item(self, idx):
        base_sample = self.base_dataset.load_item(idx)
        # assumes sample contains key "text" that is the string text
        # when using on vqa2 which returns tokens under key "text"
        # change the vqa2 dataset class to return "text"
        text_processor_argument = {"text": base_sample["text"]}
        if self.add_tags:
            text_processor_argument["text_b"] = self.get_label_str(base_sample)

            random_caption_idx = random.randint(0, len(self.annotation_db) - 1)
            random_caption_sample = self.base_dataset.load_item(random_caption_idx)
            random_caption = random_caption_sample["text"]
            text_processor_argument["random_captions"] = [random_caption]

            random_labels_idx = random.randint(0, len(self.annotation_db) - 1)
            random_labels_sample = self.base_dataset.load_item(random_labels_idx)
            random_image_tags_str = self.get_label_str(random_labels_sample)
            text_processor_argument["random_labels"] = [random_image_tags_str]

        processed_caption = self.text_processor(text_processor_argument)
        base_sample.update(processed_caption)
        return base_sample

    def load_label_map(self, map_path):
        with open(map_path) as f:
            return json.loads(f.read())

    def get_label_str(self, sample):
        image_labels = sample["image_info_0"].get("labels", [])
        label_map = self.label_map.get("idx_to_label", {})
        label_str = " ".join([label_map.get(str(id), "") for id in image_labels])
        image_attr_labels = sample["image_info_0"].get("attr_labels", [])
        attr_map = self.label_map.get("idx_to_attribute", {})
        attr_str = " ".join([attr_map.get(str(id), "") for id in image_attr_labels])
        accum_str = label_str + " " + attr_str
        return accum_str
3 changes: 2 additions & 1 deletion projects/vinvl/configs/vqa2/defaults.yaml
@@ -7,7 +7,8 @@ model_config:
num_labels: 3129

 dataset_config:
-  vqa2:
+  vinvl:
+    base_dataset_name: vqa2
     processors:
       text_processor:
         type: vinvl_text_tokenizer
51 changes: 47 additions & 4 deletions website/docs/projects/vinvl.md
@@ -36,16 +36,59 @@ dataset. Or running parallel VinVL feature extraction on an image directory usin
change the feature paths in the dataset config in `mmf/configs/datasets/<your dataset name>/defaults.yaml`
to point to your new features.

## The VinVL Dataset
The VinVL dataset is a dataset wrapper that augments an existing dataset within MMF.
The VinVL dataset doesn't contain new images or text, but introduces label and attribute tags as strings.
VinVL finetuning and pretraining require inputs that general datasets do not provide.
To enable this functionality on arbitrary datasets, the VinVL dataset contains a base dataset
and returns an augmented version of samples from the base dataset.

For example, the VQA2 dataset may return a sample `{image, text}`.
When asked for a sample, the VinVL dataset will return `{image, text', rand_caption, rand_label}`, where:

- `text'` = text + labels
- `rand_caption` = text from a random example
- `rand_label` = object detection label text for a random example
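
The augmentation described above can be sketched in plain Python. This is only an illustration under stated assumptions: `augment_sample`, the toy samples, and the toy label strings are hypothetical and are not part of MMF, and the real dataset passes these fields through its text processor rather than returning them directly.

```python
import random


def augment_sample(base_samples, idx, label_strs, rng=random):
    """Hypothetical helper: build a VinVL-style sample from toy base samples."""
    sample = dict(base_samples[idx])
    # text' = text + labels, kept here as a separate "text_b" field
    sample["text_b"] = label_strs[idx]
    # rand_caption: text taken from a randomly chosen example
    rand_caption_idx = rng.randint(0, len(base_samples) - 1)
    sample["random_captions"] = [base_samples[rand_caption_idx]["text"]]
    # rand_label: label string taken from a (possibly different) random example
    rand_label_idx = rng.randint(0, len(base_samples) - 1)
    sample["random_labels"] = [label_strs[rand_label_idx]]
    return sample


# Toy data, made up for illustration only
base = [
    {"image": "img0", "text": "what color is the cat"},
    {"image": "img1", "text": "how many dogs are there"},
]
labels = ["cat sofa black", "dog grass brown"]
out = augment_sample(base, 0, labels, random.Random(0))
```

The random caption and random labels are drawn independently, so they may come from different examples.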

The VinVL dataset assumes:
- The sample returned by the base dataset contains a key "text" holding the text as a string.
- The dataset config contains a `label_map` path to a JSON file with `idx_to_attribute` and `idx_to_label` maps.
  VinVL OD uses VG labels, and this map can be downloaded from https://penzhanwu2.blob.core.windows.net/sgg/sgg_benchmark/vinvl_model_zoo/VG-SGG-dicts-vgoi6-clipped.json
- The features_db points to features generated by the VinVL feature extraction script;
  consult the VinVL feature extraction tutorial for more details.
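
As a rough sketch of how the label map is turned into a tag string, the snippet below mirrors the lookup logic of `get_label_str` from this commit on a toy map. The map contents and the `label_string` helper name are made up for illustration; the real VG label map is much larger and its exact entries are not shown here.

```python
import json


def label_string(image_info, label_map):
    """Join object labels and attribute labels into one space-separated string."""
    idx_to_label = label_map.get("idx_to_label", {})
    idx_to_attr = label_map.get("idx_to_attribute", {})
    # JSON object keys are strings, so integer ids are converted before lookup
    labels = [idx_to_label.get(str(i), "") for i in image_info.get("labels", [])]
    attrs = [idx_to_attr.get(str(i), "") for i in image_info.get("attr_labels", [])]
    return " ".join(labels) + " " + " ".join(attrs)


# Toy label map, not the real VG-SGG dictionary
toy_map = json.loads(
    '{"idx_to_label": {"1": "cat", "2": "sofa"}, "idx_to_attribute": {"7": "black"}}'
)
result = label_string({"labels": [1, 2], "attr_labels": [7]}, toy_map)
```

Ids missing from the map resolve to empty strings, so unknown ids leave extra spaces in the result rather than raising an error.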

This is why for VinVL finetuning and pretraining you should use dataset=vinvl,
then specify your base dataset in your configs.
Here is an example from projects/vinvl/configs/vqa2/defaults.yaml.

```yaml
includes:
- ../vqa2/defaults.yaml

dataset_config:
  vinvl:
    base_dataset_name: vqa2
    label_map: /private/home/ryanjiang/winoground/pretrained_models/VG-SGG-dicts-vgoi6-clipped.json
    base_dataset: ${dataset_config.vqa2}
    processors:
      text_processor:
        type: vinvl_text_tokenizer
        params:
          mask_probability: 0
```
Where vqa2/defaults.yaml contains the feature paths pointing to VinVL features.
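
To make the role of each config key concrete, here is a minimal sketch of how the builder in this commit combines the VinVL config with the base dataset config. It uses plain dictionaries instead of OmegaConf, and it copies the base config where the real builder updates it in place under `open_dict`. The function name `resolve_vinvl_config`, the `base_text_tokenizer` type, and the paths are placeholders assumed for illustration.

```python
import copy


def resolve_vinvl_config(config):
    """Hypothetical helper: return (base dataset name, config for the VinVL dataset)."""
    # Same fallbacks as the builder: default to vqa2, reuse config if no base_dataset
    base_name = config.get("base_dataset_name", "vqa2")
    base_config = copy.deepcopy(config.get("base_dataset", config))
    # The VinVL text processor replaces the base dataset's text processor
    base_config["processors"]["text_processor"] = config["processors"]["text_processor"]
    # The label map path is carried over so the VinVL dataset can load it
    base_config["label_map"] = config["label_map"]
    return base_name, base_config


# Toy config with placeholder values
vinvl_cfg = {
    "base_dataset_name": "vqa2",
    "label_map": "path/to/label_map.json",
    "base_dataset": {
        "features": "path/to/vinvl/features",
        "processors": {"text_processor": {"type": "base_text_tokenizer"}},
    },
    "processors": {
        "text_processor": {
            "type": "vinvl_text_tokenizer",
            "params": {"mask_probability": 0},
        }
    },
}
name, merged = resolve_vinvl_config(vinvl_cfg)
```

The merged config keeps the base dataset's feature paths, which is why the base `defaults.yaml` must already point to VinVL features.
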
## Training
After extracting features and redirecting your dataset config,
to train VinVL model from scratch on the VQA2.0 dataset, run the following command
```diff
-mmf_run config=projects/vinvl/configs/vqa2/defaults.yaml run_type=train_val dataset=vqa2 model=vinvl
+mmf_run config=projects/vinvl/configs/vqa2/defaults.yaml run_type=train dataset=vinvl model=vinvl
```

To finetune a pretrained VinVL model on the VQA2.0 dataset, run the following command
```diff
-mmf_run config=projects/vinvl/configs/vqa2/defaults.yaml run_type=train_val dataset=vqa2 model=vinvl checkpoint.resume_zoo=vinvl.pretrained
+mmf_run config=projects/vinvl/configs/vqa2/defaults.yaml run_type=train dataset=vinvl model=vinvl checkpoint.resume_zoo=vinvl.pretrained
```
