From 4b38fa133ec9863506713e0df532a586e2a985f5 Mon Sep 17 00:00:00 2001
From: Zhicheng Yan
Date: Fri, 10 Apr 2020 17:15:24 -0700
Subject: [PATCH] mixup data augmentation (#469)

Summary:
Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/469

This diff implements the mixup data augmentation described in the paper
`mixup: Beyond Empirical Risk Minimization` (https://arxiv.org/abs/1710.09412).
Empirically, applying the mixup transform on the GPU is much faster than
applying it on the CPU.

# Results
Accuracy gain:
- 1.0% with 135 training epochs
- 1.3% with 270 training epochs

[TODO]: fix the accuracy meter during training phases.
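As a quick illustration (not part of the diff itself), the sketch below shows how the transform is meant to be applied to a batch dict; in the task it runs after the batch has been copied to GPU. The shapes and values mirror the unit test and are example values only:

```python
import torch

from classy_vision.dataset.transforms.mixup import MixupTransform

# alpha parameterizes the Beta(alpha, alpha) distribution that the mixing
# coefficient is sampled from; num_classes is only needed when "target" holds
# 1D integer class labels (they are converted to one-hot before mixing).
mixup = MixupTransform(alpha=2.0, num_classes=3)

sample = {
    "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
    "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
}
sample = mixup(sample)  # "target" is now a 4 x 3 soft-label tensor
```

In a task config the transform is enabled through an optional `mixup` section, e.g. `"mixup": {"alpha": 2.0, "num_classes": 3}` (only `alpha` is required); see the `from_config` change in `classification_task.py` below.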
+ """ + if sample["target"].ndim == 1: + assert self.num_classes is not None, "num_classes is expected for 1D target" + sample["target"] = convert_to_one_hot( + sample["target"].view(-1, 1), self.num_classes + ) + else: + assert sample["target"].ndim == 2, "target tensor shape must be 1D or 2D" + + c = Beta(self.alpha, self.alpha).sample().to(device=sample["target"].device) + permuted_indices = torch.randperm(sample["target"].shape[0]) + for key in ["input", "target"]: + sample[key] = c * sample[key] + (1.0 - c) * sample[key][permuted_indices, :] + + return sample diff --git a/classy_vision/generic/util.py b/classy_vision/generic/util.py index cf413e4420..1fb56a9a90 100644 --- a/classy_vision/generic/util.py +++ b/classy_vision/generic/util.py @@ -736,12 +736,11 @@ def maybe_convert_to_one_hot(target, model_output): ): target = convert_to_one_hot(target.view(-1, 1), model_output.shape[1]) - assert (target.shape == model_output.shape) and ( - torch.min(target.eq(0) + target.eq(1)) == 1 - ), ( - "Target must be one-hot/multi-label encoded and of the " - "same shape as model_output." - ) + # target are not necessarily hard 0/1 encoding. It can be soft + # (i.e. fractional) in some cases, such as mixup label + assert ( + target.shape == model_output.shape + ), "Target must of the same shape as model_output." return target diff --git a/classy_vision/losses/soft_target_cross_entropy_loss.py b/classy_vision/losses/soft_target_cross_entropy_loss.py index a8d67e7a8a..91350e4520 100644 --- a/classy_vision/losses/soft_target_cross_entropy_loss.py +++ b/classy_vision/losses/soft_target_cross_entropy_loss.py @@ -10,6 +10,7 @@ import numpy as np import torch import torch.nn.functional as F +from classy_vision.generic.util import convert_to_one_hot from classy_vision.losses import ClassyLoss, register_loss @@ -58,13 +59,19 @@ def from_config(cls, config: Dict[str, Any]) -> "SoftTargetCrossEntropyLoss": def forward(self, output, target): """for N examples and C classes - output: N x C these are raw outputs (without softmax/sigmoid) - - target: N x C corresponding targets + - target: N x C or N corresponding targets Target elements set to ignore_index contribute 0 loss. Samples where all entries are ignore_index do not contribute to the loss reduction. 
""" + # check if targets are inputted as class integers + if target.ndim == 1: + assert ( + output.shape[0] == target.shape[0] + ), "SoftTargetCrossEntropyLoss requires output and target to have same batch size" + target = convert_to_one_hot(target.view(-1, 1), output.shape[1]) assert ( output.shape == target.shape ), "SoftTargetCrossEntropyLoss requires output and target to be same" diff --git a/classy_vision/meters/accuracy_meter.py b/classy_vision/meters/accuracy_meter.py index 0936431d1d..4738cd1d05 100644 --- a/classy_vision/meters/accuracy_meter.py +++ b/classy_vision/meters/accuracy_meter.py @@ -145,7 +145,7 @@ def update(self, model_output, target, **kwargs): for i, k in enumerate(self._topk): self._curr_correct_predictions_k[i] += ( torch.gather(target, dim=1, index=pred[:, :k]) - .long() + # .long() .max(dim=1) .values.sum() .item() diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index da29e029b4..ad4538e0e3 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn from classy_vision.dataset import ClassyDataset, build_dataset +from classy_vision.dataset.transforms.mixup import MixupTransform from classy_vision.generic.distributed_util import ( all_reduce_mean, barrier, @@ -141,6 +142,7 @@ def __init__(self): BroadcastBuffersMode.DISABLED ) self.amp_args = None + self.mixup_transform = None self.perf_log = [] self.last_batch = None self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED @@ -326,6 +328,19 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]): logging.info(f"AMP enabled with args {amp_args}") return self + def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]): + """Disable / enable mixup transform for data augmentation + + Args:: + mixup_transform: a callable object which performs mixup data augmentation + """ + self.mixup_transform = mixup_transform + if mixup_transform is None: + logging.info(f"mixup disabled") + else: + logging.info(f"mixup enabled") + return self + @classmethod def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask": """Instantiates a ClassificationTask from a configuration. 
@@ -353,6 +368,13 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         meters = build_meters(config.get("meters", {}))
         model = build_model(config["model"])
 
+        mixup_transform = None
+        if config.get("mixup") is not None:
+            assert "alpha" in config["mixup"], "key alpha is missing in mixup dict"
+            mixup_transform = MixupTransform(
+                config["mixup"]["alpha"], config["mixup"].get("num_classes")
+            )
+
         # hooks config is optional
         hooks_config = config.get("hooks")
         hooks = []
@@ -371,6 +393,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_optimizer(optimizer)
             .set_meters(meters)
             .set_amp_args(amp_args)
+            .set_mixup_transform(mixup_transform)
             .set_distributed_options(
                 broadcast_buffers_mode=BroadcastBuffersMode[
                     config.get("broadcast_buffers", "disabled").upper()
@@ -775,6 +798,9 @@ def train_step(self):
             for key, value in sample.items():
                 sample[key] = recursive_copy_to_gpu(value, non_blocking=True)
 
+        if self.mixup_transform is not None:
+            sample = self.mixup_transform(sample)
+
         with torch.enable_grad():
             # Forward pass
             output = self.model(sample["input"])
diff --git a/test/dataset_transforms_mixup_test.py b/test/dataset_transforms_mixup_test.py
new file mode 100644
index 0000000000..d77248388a
--- /dev/null
+++ b/test/dataset_transforms_mixup_test.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import unittest
+
+import torch
+from classy_vision.dataset.transforms.mixup import MixupTransform
+
+
+class DatasetTransformsMixupTest(unittest.TestCase):
+    def test_mixup_transform_single_label(self):
+        alpha = 2.0
+        num_classes = 3
+        mixup_transform = MixupTransform(alpha, num_classes)
+        sample = {
+            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
+            "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
+        }
+        sample_mixup = mixup_transform(sample)
+        self.assertTrue(sample["input"].shape == sample_mixup["input"].shape)
+        self.assertTrue(sample_mixup["target"].shape[0] == 4)
+        self.assertTrue(sample_mixup["target"].shape[1] == 3)
+
+    def test_mixup_transform_single_label_missing_num_classes(self):
+        alpha = 2.0
+        mixup_transform = MixupTransform(alpha, None)
+        sample = {
+            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
+            "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
+        }
+        with self.assertRaises(Exception):
+            mixup_transform(sample)
+
+    def test_mixup_transform_multi_label(self):
+        alpha = 2.0
+        mixup_transform = MixupTransform(alpha, None)
+        sample = {
+            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
+            "target": torch.as_tensor(
+                [[1, 0, 0, 0], [0, 1, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1]],
+                dtype=torch.int32,
+            ),
+        }
+        sample_mixup = mixup_transform(sample)
+        self.assertTrue(sample["input"].shape == sample_mixup["input"].shape)
+        self.assertTrue(sample["target"].shape == sample_mixup["target"].shape)
diff --git a/test/losses_soft_target_cross_entropy_loss_test.py b/test/losses_soft_target_cross_entropy_loss_test.py
index 47374ecdb3..ab2fea58d6 100644
--- a/test/losses_soft_target_cross_entropy_loss_test.py
+++ b/test/losses_soft_target_cross_entropy_loss_test.py
@@ -47,6 +47,13 @@ def test_soft_target_cross_entropy(self):
         targets = torch.tensor([[-1, 0, 0, 0, 1]])
         self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)
 
+    def test_soft_target_cross_entropy_integer_label(self):
+        config = self._get_config()
+        crit = SoftTargetCrossEntropyLoss.from_config(config)
+        outputs = self._get_outputs()
+        targets = torch.tensor([4])
+        self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)
+
     def test_unnormalized_soft_target_cross_entropy(self):
         config = {
             "name": "soft_target_cross_entropy",