From d84d1fc5d22eeaf56f5236a30e56a777a629a487 Mon Sep 17 00:00:00 2001 From: Tomonobu Tsujikawa Date: Wed, 30 Mar 2022 18:21:43 +0900 Subject: [PATCH] add image augmentation for csv dataloader. --- nnabla_nas/dataset/csv.py | 61 ++++++++++++++++++++++++++++- nnabla_nas/utils/data/transforms.py | 36 +++++++++++++++++ 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/nnabla_nas/dataset/csv.py b/nnabla_nas/dataset/csv.py index 78574f9c..93549e1a 100644 --- a/nnabla_nas/dataset/csv.py +++ b/nnabla_nas/dataset/csv.py @@ -17,6 +17,7 @@ from nnabla.utils.load import _create_dataset from .dataloader import BaseDataLoader +from ..utils.data import transforms def get_sliced_data_iterator(dataset, comm, training, portion): @@ -52,6 +53,7 @@ class DataLoader(BaseDataLoader): will be used for validation. Defaults to 1.0. This is only considered when searching is `True`. rng (:obj:`numpy.random.RandomState`), optional): Numpy random number generator. Defaults to None. + augmentation (dict, optional): Information on how to augment. Defaults to None. communicator (Communicator, optional): The communicator is used to support distributed learning. Defaults to None. """ @@ -59,8 +61,10 @@ class DataLoader(BaseDataLoader): def __init__(self, batch_size=1, searching=False, training=False, train_file=None, valid_file=None, train_cache_dir=None, valid_cache_dir=None, - train_portion=1.0, rng=None, communicator=None): + train_portion=1.0, rng=None, augmentation=None, + communicator=None): self.rng = rng or random.prng + self.augmentation = augmentation if searching: file = train_file @@ -93,3 +97,58 @@ def __len__(self): def next(self): x, y = self._data.next() return {"inputs": [x], "targets": [y]} + + def transform(self, key='train'): + r"""Return a transform applied to data augmentation.""" + assert key in ('train', 'valid') + + if self.augmentation: + type = self.augmentation.get('type') + norm = self.augmentation.get('normalize') + else: + type = None + norm = None + + if type == 'cifar10': + mean = (0.49139968, 0.48215827, 0.44653124) + std = (0.24703233, 0.24348505, 0.26158768) + scale = 1./255.0 + elif type == 'imagenet': + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + scale = 1./255.0 + else: + mean = (0.0, 0.0, 0.0) + std = (1.0, 1.0, 1.0) + scale = 1./255.0 + + if key == 'train': + if type == 'cifar10': + pad_width = (4, 4, 4, 4) + return transforms.Compose([ + transforms.Cutout(8, prob=1, seed=123), + transforms.Normalize(mean=mean, std=std, scale=scale), + transforms.RandomCrop((3, 32, 32), pad_width=pad_width), + transforms.RandomHorizontalFlip() + ]) + elif type == 'imagenet': + return transforms.Compose([ + transforms.Normalize(mean=mean, std=std, scale=scale), + transforms.RandomResizedCrop((3, 224, 224), + scale=(1.0, 2.3), ratio=1.33), + transforms.RandomHorizontalFlip() + ]) + else: + pass # same as valid + + if type == 'cifar10' or norm: + return transforms.Compose([ + transforms.Normalize(mean=mean, std=std, scale=scale) + ]) + elif type == 'imagenet': + return transforms.Compose([ + transforms.Resize(size=(224, 224)), + transforms.Normalize(mean=mean, std=std, scale=scale) + ]) + else: + return transforms.Compose([]) diff --git a/nnabla_nas/utils/data/transforms.py b/nnabla_nas/utils/data/transforms.py index be7a975c..07649885 100644 --- a/nnabla_nas/utils/data/transforms.py +++ b/nnabla_nas/utils/data/transforms.py @@ -178,6 +178,42 @@ def __str__(self): ) +class RandomResizedCrop(object): + r"""Crop a random portion of image and resize it. + + Args: + shape (tuple of `int`): The output image shape. + scale (tuple of `float`): lower and upper scale ratio when randomly + scaling the image. + ratio (`float`): The aspect ratio range when randomly deforming + the image. For example, to deform aspect ratio of image from + 1:1.3 to 1.3:1, specify "1.3". To not apply random deforming, + specify "1.0". + interpolation (str): Interpolation mode chosen from + ('linear'|'nearest'). The default is 'linear'. + """ + + def __init__(self, shape, scale=None, ratio=None, interpolation='linear'): + self._shape = shape + self._scale = scale + self._ratio = ratio + self._interpolation = interpolation + + def __call__(self, input): + return F.image_augmentation( + input, shape=self._shape, + min_scale=self._scale[0], max_scale=self._scale[1], + aspect_ratio=self._ratio) + + def __str__(self): + return self.__class__.__name__ + ( + f'(shape={self._shape}, ' + f'scale={self._scale}, ' + f'ratio={self._ratio}, ' + f'interpolation={self.interpolation})' + ) + + class RandomHorizontalFlip(object): r"""Horizontally flip the given Image randomly with a probability 0.5."""