[feat] Add classification fine-tuning utilities

This PR aims at adding starter classification utils to the FLAVA examples.

As of now, the PR adds the following:
- Finetuning trainer
- Classification FLAVA
- TorchVisionDataModule for easy composability of datasets from torchvision (see the usage sketch after this list)
- Some changes to the MLP module to make it more general
- Some improvements/bug fixes to the original FLAVA code
- Splits the datamodules to better serve their individual concerns (see the second sketch, after the test plan)

TODOs:
- Add support for the rest of the datasets. This involves leveraging the existing datamodules created in this PR, along with support for seamlessly plugging in different datasets
- Add command-line overrides on top
- Add support for retrieval, zero-shot and other downstream tasks in an
easily accessible form
- Expose more outputs from the model than just the loss

Test Plan:

The code is not yet in a fully working state; I have tested only the changes in this PR. I expect everything to be stable by the end of the stack.
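
As a second sketch, the datamodule split means plain text batches go through TextDataModule with a DefaultDataCollator, while MLMDataModule layers masked-language-modeling collation on top. A hypothetical wiring (the wikitext subset and tokenizer choice are illustrative, not what this PR tests):

    from transformers import BertTokenizer

    from data import HFDatasetInfo, MLMDataModule

    mlm_datamodule = MLMDataModule(
        dataset_infos=[HFDatasetInfo(key="wikitext", subset="wikitext-103-raw-v1")],
        tokenizer=BertTokenizer.from_pretrained("bert-base-uncased"),
        mlm_probability=0.15,
        batch_size=32,
    )
    mlm_datamodule.setup()
    # After the batch-transfer hooks run, batches carry "text_masked" and
    # "mlm_labels" keys for the FLAVA pretraining loop.
    train_loader = mlm_datamodule.train_dataloader()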

ghstack-source-id: 2c0b03cde9ca54f662c20c4f6d40b73cc1b306cb
Pull Request resolved: #8
apsdehal committed Apr 1, 2022
1 parent cbf5d53 commit e6d8222
Showing 7 changed files with 420 additions and 80 deletions.
222 changes: 176 additions & 46 deletions examples/flava/data.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import warnings
from dataclasses import dataclass, field
@@ -19,7 +20,12 @@
from PIL import Image, UnidentifiedImageError
from pytorch_lightning import LightningDataModule
from torchvision.datasets import ImageFolder
from transformers import BertTokenizer, DataCollatorForLanguageModeling
from transformers import (
BertTokenizer,
DefaultDataCollator,
DataCollatorForLanguageModeling,
TRANSFORMERS_CACHE,
)
from transforms import (
RandomResizedCropAndInterpolationWithTwoPic,
MaskingGenerator,
@@ -28,6 +34,10 @@
)


PRETRAINING_IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
PRETRAINING_IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)


class MaskedImageModelingTransform:
def __init__(
self,
@@ -36,8 +46,8 @@ def __init__(
scale: Tuple[float, float] = (0.9, 1.0),
encoder_interpolation: str = "bicubic",
codebook_interpolation: str = "lanczos",
image_mean: Tuple[float, float, float] = (0.48145466, 0.4578275, 0.40821073),
image_std: Tuple[float, float, float] = (0.26862954, 0.26130258, 0.27577711),
image_mean: Tuple[float, float, float] = PRETRAINING_IMAGE_MEAN,
image_std: Tuple[float, float, float] = PRETRAINING_IMAGE_STD,
mask_window_size: int = 14,
mask_num_patches: int = 75,
mask_max_patches: Optional[int] = None,
@@ -97,7 +107,7 @@ def __call__(self, images: Union[List[Image.Image], Image.Image]):
return self.transform(images)


def default_image_transforms():
def default_image_pretraining_transforms():
return MaskedImageModelingTransform(), MaskedImageModelingTransform()


@@ -117,9 +127,9 @@ def __init__(
val_root: str,
transforms: Optional[Tuple[Callable, Callable]] = None,
use_subset_sampler: bool = False,
batch_size=32,
num_workers=4,
allow_uneven_batches=False,
batch_size: int = 32,
num_workers: int = 4,
allow_uneven_batches: bool = False,
**kwargs,
):
super().__init__()
@@ -131,7 +141,7 @@ def __init__(
self.use_subset_sampler = use_subset_sampler

if transforms is None:
transforms = default_image_transforms()
transforms = default_image_pretraining_transforms()

self.train_transform, self.test_transform = transforms

@@ -168,6 +178,7 @@ def train_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=sampler,
shuffle=True,
# uneven batches can cause distributed issues,
# drop last batch to prevent those.
# ideally, we don't need to drop these for unimodal cases
@@ -181,6 +192,7 @@ def val_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=None,
shuffle=False,
# uneven batches can cause distributed issues,
# drop last batch to prevent those.
# ideally, we don't need to drop these for unimodal cases
@@ -204,7 +216,7 @@ def _default_split_key_mapping():


@dataclass
class HFDatasetsInfo:
class HFDatasetInfo:
key: str
subset: str
remove_columns: Optional[List[str]] = None
@@ -215,9 +227,17 @@ class HFDatasetsInfo:
)


def _build_datasets_from_info(
dataset_infos: List[HFDatasetsInfo], split: str = "train"
):
@dataclass
class TorchVisionDatasetInfo:
key: str
class_ptr: torch.utils.data.Dataset
train_split: str = "train"
val_split: str = "val"
has_val: bool = True
test_split: str = "test"


def _build_datasets_from_info(dataset_infos: List[HFDatasetInfo], split: str = "train"):
dataset_list = []
for dataset_info in dataset_infos:
current_dataset = load_dataset(
@@ -246,27 +266,23 @@ def _encode_text(text, tokenizer, *args, **kwargs):
return tokenizer(text, *args, **kwargs)


class MLMDataModule(LightningDataModule):
class TextDataModule(LightningDataModule):
def __init__(
self,
dataset_infos: List[HFDatasetsInfo],
dataset_infos: List[HFDatasetInfo],
tokenizer: Optional[Callable] = None,
mlm_probablity: float = 0.15,
max_length: int = 512,
batch_size=32,
num_workers=4,
ignore_index=-1,
allow_uneven_batches=False,
batch_size: int = 32,
num_workers: int = 4,
allow_uneven_batches: bool = False,
**kwargs,
):
super().__init__()
self.dataset_infos = dataset_infos
self.tokenizer = tokenizer
self.mlm_probability = mlm_probablity
self.max_length = max_length
self.batch_size = batch_size
self.num_workers = num_workers
self.ignore_index = ignore_index
self.allow_uneven_batches = allow_uneven_batches

def setup(self, stage=None):
@@ -292,37 +308,24 @@ def setup(self, stage=None):
self.val_dataset.set_transform(transform)

def train_dataloader(self):
return torch.utils.data.DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=None,
collate_fn=self._build_collator(),
# uneven batches can cause distributed issues,
# drop last batch to prevent those.
# ideally, we don't need to drop these for unimodal cases
# but just to be safe
drop_last=True,
)
return self._build_dataloader(self.train_dataset)

def val_dataloader(self):
return self._build_dataloader(self.val_dataset, shuffle=False)

def _build_dataloader(self, dataset, drop_last=False, shuffle=True):
return torch.utils.data.DataLoader(
self.val_dataset,
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=None,
shuffle=shuffle,
collate_fn=self._build_collator(),
# uneven batches can cause distributed issues,
# drop last batch to prevent those.
# ideally, we don't need to drop these for unimodal cases
# but just to be safe
drop_last=True,
drop_last=drop_last,
)

def _build_collator(self):
return DataCollatorForLanguageModeling(
self.tokenizer, mlm_probability=self.mlm_probability
)
return DefaultDataCollator()

def on_before_batch_transfer(self, batch, *args):
batch.pop("token_type_ids", None)
@@ -331,6 +334,35 @@ def on_before_batch_transfer(self, batch, *args):
batch = pad_batch(batch, self.batch_size)
return batch

def on_after_batch_transfer(self, batch, *args):
batch["text_masked"] = batch.pop("input_ids")
return batch


class MLMDataModule(TextDataModule):
def __init__(
self,
dataset_infos: List[HFDatasetInfo],
mlm_probability: float = 0.15,
ignore_index: int = -1,
**kwargs: Any,
):
super().__init__(dataset_infos, **kwargs)
self.mlm_probability = mlm_probability
self.ignore_index = ignore_index

def _build_dataloader(self, dataset, drop_last=True, shuffle=True):
# uneven batches can cause distributed issues,
# drop last batch to prevent those.
# ideally, we don't need to drop these for unimodal cases
# but just to be safe
return super()._build_dataloader(dataset, drop_last=drop_last, shuffle=shuffle)

def _build_collator(self):
return DataCollatorForLanguageModeling(
self.tokenizer, mlm_probability=self.mlm_probability
)

def on_after_batch_transfer(self, batch, *args):
batch["text_masked"] = batch.pop("input_ids")
batch["mlm_labels"] = batch.pop("labels")
@@ -404,8 +436,8 @@ def fetch_images(sample, timeout):
class VLDataModule(LightningDataModule):
def __init__(
self,
train_dataset_infos: List[HFDatasetsInfo],
val_dataset_infos: List[HFDatasetsInfo],
train_dataset_infos: List[HFDatasetInfo],
val_dataset_infos: List[HFDatasetInfo],
text_tokenizer: Optional[Callable] = None,
image_transforms: Optional[Tuple[Callable, Callable]] = None,
mlm_probablity: float = 0.15,
@@ -414,7 +446,7 @@ def __init__(
num_workers: int = 4,
ignore_index: int = -1,
itm_probability: float = 0.1,
allow_uneven_batches=False,
allow_uneven_batches: bool = False,
**kwargs,
):
super().__init__()
@@ -423,7 +455,7 @@ def __init__(
self.val_dataset_infos = val_dataset_infos

if image_transforms is None:
image_transforms = default_image_transforms()
image_transforms = default_image_pretraining_transforms()

self.train_image_transform, self.test_image_transform = image_transforms
self.text_tokenizer = text_tokenizer
@@ -488,6 +520,7 @@ def train_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=None,
shuffle=True,
collate_fn=self._build_collator(),
# uneven batches can cause distributed issues,
# drop last batch to prevent those.
@@ -500,6 +533,7 @@ def val_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
sampler=None,
shuffle=False,
collate_fn=self._build_collator(),
# uneven batches can cause distributed issues,
# drop last batch to prevent those.
@@ -531,6 +565,102 @@ def on_after_batch_transfer(self, batch, *args):
return batch


FINETUNING_IMAGE_MEAN = (0.485, 0.456, 0.406)
FINETUNING_IMAGE_STD = (0.229, 0.224, 0.225)


def default_torchvision_transforms():
transform = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
mean=FINETUNING_IMAGE_MEAN,
std=FINETUNING_IMAGE_STD,
),
]
)
return transform, transform


class TorchVisionDataModule(LightningDataModule):
def __init__(
self,
dataset_info: TorchVisionDatasetInfo,
dataset_root: Optional[str] = None,
image_transforms: Optional[Tuple[Callable, Callable]] = None,
batch_size: int = 32,
num_workers: int = 4,
**kwargs,
):
super().__init__()

if dataset_root is None:
dataset_root = os.path.join(TRANSFORMERS_CACHE, "datasets", "torchvision")
dataset_root = os.path.join(
dataset_root, dataset_info.class_ptr.__name__.lower()
)
os.makedirs(dataset_root, exist_ok=True)

self.dataset_info = dataset_info
self.dataset_root = dataset_root
if image_transforms is None:
image_transforms = default_torchvision_transforms()
self.train_transform, self.test_transform = image_transforms
self.batch_size = batch_size
self.num_workers = num_workers

def setup(self, stage=None):
self.train_dataset = self.dataset_info.class_ptr(
self.dataset_root,
split=self.dataset_info.train_split,
transform=self.train_transform,
download=True,
)

if self.dataset_info.has_val:
self.val_dataset = self.dataset_info.class_ptr(
self.dataset_root,
split=self.dataset_info.val_split,
transform=self.test_transform,
download=True,
)

self.test_dataset = self.dataset_info.class_ptr(
self.dataset_root,
split=self.dataset_info.test_split,
transform=self.test_transform,
download=True,
)

def train_dataloader(self):
return self._build_dataloader(self.train_dataset)

def val_dataloader(self):
if self.dataset_info.has_val:
dataset = self.val_dataset
else:
dataset = self.test_dataset

return self._build_dataloader(dataset, shuffle=False)

def test_dataloader(self):
return self._build_dataloader(self.test_dataset, shuffle=False)

def _build_dataloader(self, dataset: torch.utils.data.Dataset, shuffle=True):
return torch.utils.data.DataLoader(
dataset,
shuffle=shuffle,
batch_size=self.batch_size,
num_workers=self.num_workers,
)

def on_before_batch_transfer(self, batch, *args):
images, targets = batch
batch = {"image": images, "labels": targets}
return batch


class MultiDataLoader:
# NOTE: Please check MMF's MultiDataLoader if you want to support
# size proportional strategies or epoch based runs