This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

mixup data augmentation #469

Closed
48 changes: 48 additions & 0 deletions classy_vision/dataset/transforms/mixup.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch
from classy_vision.generic.util import convert_to_one_hot
from torch.distributions.beta import Beta


class MixupTransform:
    """
    This implements the mixup data augmentation in the paper
    "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412)
    """

    def __init__(self, alpha: float, num_classes: Optional[int] = None):
        """
        Args:
            alpha: the hyperparameter of the Beta distribution used to sample the
                mixup coefficient.
            num_classes: the number of classes in the dataset.
        """
        self.alpha = alpha
        self.num_classes = num_classes

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            sample: the batch data.
        """
        if sample["target"].ndim == 1:
            assert self.num_classes is not None, "num_classes is expected for 1D target"
            sample["target"] = convert_to_one_hot(
                sample["target"].view(-1, 1), self.num_classes
            )
        else:
            assert sample["target"].ndim == 2, "target tensor shape must be 1D or 2D"

        c = Beta(self.alpha, self.alpha).sample().to(device=sample["target"].device)
        permuted_indices = torch.randperm(sample["target"].shape[0])
        for key in ["input", "target"]:
            sample[key] = c * sample[key] + (1.0 - c) * sample[key][permuted_indices, :]

        return sample
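For illustration (not part of the PR), a minimal sketch of how this transform is meant to be called on a batch dict; the batch size, class count, and image shape below are arbitrary:

# Illustrative usage sketch of the MixupTransform added in this PR.
import torch

from classy_vision.dataset.transforms.mixup import MixupTransform

mixup = MixupTransform(alpha=2.0, num_classes=10)
batch = {
    "input": torch.rand(8, 3, 224, 224),   # images
    "target": torch.randint(0, 10, (8,)),  # integer class labels
}
mixed = mixup(batch)
# Integer targets are first converted to one-hot, then both "input" and
# "target" become convex combinations of the batch with a permuted copy:
#   mixed = c * batch + (1 - c) * batch[permutation], with c ~ Beta(alpha, alpha)
print(mixed["target"].shape)  # torch.Size([8, 10]), now soft labels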
11 changes: 5 additions & 6 deletions classy_vision/generic/util.py
@@ -736,12 +736,11 @@ def maybe_convert_to_one_hot(target, model_output):
    ):
        target = convert_to_one_hot(target.view(-1, 1), model_output.shape[1])

-    assert (target.shape == model_output.shape) and (
-        torch.min(target.eq(0) + target.eq(1)) == 1
-    ), (
-        "Target must be one-hot/multi-label encoded and of the "
-        "same shape as model_output."
-    )
+    # target is not necessarily a hard 0/1 encoding; it can be soft
+    # (i.e. fractional) in some cases, such as mixup labels
+    assert (
+        target.shape == model_output.shape
+    ), "Target must be of the same shape as model_output."

    return target

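For illustration (not part of the PR), a small sketch of what the relaxed check allows; it assumes the unchanged top of maybe_convert_to_one_hot leaves 2D targets untouched:

import torch

from classy_vision.generic.util import maybe_convert_to_one_hot

model_output = torch.randn(2, 3)
soft_target = torch.tensor([[0.7, 0.3, 0.0], [0.0, 0.4, 0.6]])  # mixup-style labels
# The old check (torch.min(target.eq(0) + target.eq(1)) == 1) rejected soft labels;
# the new check only compares shapes, so soft targets pass through unchanged.
assert maybe_convert_to_one_hot(soft_target, model_output).shape == model_output.shape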
9 changes: 8 additions & 1 deletion classy_vision/losses/soft_target_cross_entropy_loss.py
@@ -10,6 +10,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from classy_vision.generic.util import convert_to_one_hot
from classy_vision.losses import ClassyLoss, register_loss


@@ -58,13 +59,19 @@ def from_config(cls, config: Dict[str, Any]) -> "SoftTargetCrossEntropyLoss":
    def forward(self, output, target):
        """for N examples and C classes
        - output: N x C these are raw outputs (without softmax/sigmoid)
-        - target: N x C corresponding targets
+        - target: N x C or N corresponding targets

        Target elements set to ignore_index contribute 0 loss.

        Samples where all entries are ignore_index do not contribute to the loss
        reduction.
        """
+        # check if targets are provided as class integers
+        if target.ndim == 1:
+            assert (
+                output.shape[0] == target.shape[0]
+            ), "SoftTargetCrossEntropyLoss requires output and target to have the same batch size"
+            target = convert_to_one_hot(target.view(-1, 1), output.shape[1])
        assert (
            output.shape == target.shape
        ), "SoftTargetCrossEntropyLoss requires output and target to be same"
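For context (not part of the PR), a simplified sketch of the soft-target cross entropy this loss computes; it skips ignore_index handling and target normalization, but shows why fractional mixup labels are valid targets here:

import torch
import torch.nn.functional as F

def soft_target_cross_entropy(output, target):
    # output: N x C raw logits; target: N x C soft (possibly fractional) labels
    log_probs = F.log_softmax(output, dim=1)
    return -(target * log_probs).sum(dim=1).mean()

logits = torch.randn(4, 5)
soft_targets = torch.softmax(torch.randn(4, 5), dim=1)  # e.g. mixup-style labels
print(soft_target_cross_entropy(logits, soft_targets))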
2 changes: 1 addition & 1 deletion classy_vision/meters/accuracy_meter.py
@@ -145,7 +145,7 @@ def update(self, model_output, target, **kwargs):
        for i, k in enumerate(self._topk):
            self._curr_correct_predictions_k[i] += (
                torch.gather(target, dim=1, index=pred[:, :k])
-                .long()
+                # .long()
                .max(dim=1)
                .values.sum()
                .item()
26 changes: 26 additions & 0 deletions classy_vision/tasks/classification_task.py
@@ -14,6 +14,7 @@
import torch
import torch.nn as nn
from classy_vision.dataset import ClassyDataset, build_dataset
from classy_vision.dataset.transforms.mixup import MixupTransform
from classy_vision.generic.distributed_util import (
    all_reduce_mean,
    barrier,
@@ -141,6 +142,7 @@ def __init__(self):
            BroadcastBuffersMode.DISABLED
        )
        self.amp_args = None
        self.mixup_transform = None
        self.perf_log = []
        self.last_batch = None
        self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
@@ -326,6 +328,19 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
logging.info(f"AMP enabled with args {amp_args}")
return self

def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]):
"""Disable / enable mixup transform for data augmentation

Args::
mixup_transform: a callable object which performs mixup data augmentation
"""
self.mixup_transform = mixup_transform
if mixup_transform is None:
logging.info(f"mixup disabled")
else:
logging.info(f"mixup enabled")
return self

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        """Instantiates a ClassificationTask from a configuration.
@@ -353,6 +368,13 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
        meters = build_meters(config.get("meters", {}))
        model = build_model(config["model"])

        mixup_transform = None
        if config.get("mixup") is not None:
            assert "alpha" in config["mixup"], "key alpha is missing in mixup dict"
            mixup_transform = MixupTransform(
                config["mixup"]["alpha"], config["mixup"].get("num_classes")
            )

        # hooks config is optional
        hooks_config = config.get("hooks")
        hooks = []
@@ -371,6 +393,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
            .set_optimizer(optimizer)
            .set_meters(meters)
            .set_amp_args(amp_args)
            .set_mixup_transform(mixup_transform)
            .set_distributed_options(
                broadcast_buffers_mode=BroadcastBuffersMode[
                    config.get("broadcast_buffers", "disabled").upper()
@@ -775,6 +798,9 @@ def train_step(self):
            for key, value in sample.items():
                sample[key] = recursive_copy_to_gpu(value, non_blocking=True)

        if self.mixup_transform is not None:
            sample = self.mixup_transform(sample)

        with torch.enable_grad():
            # Forward pass
            output = self.model(sample["input"])
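For reference (not part of the PR), a sketch of the config fragment that from_config above would read to enable mixup; the alpha and num_classes values are arbitrary and every other task key is elided:

task_config = {
    # ... other task keys (dataset, model, loss, optimizer, meters) elided ...
    "mixup": {
        "alpha": 0.2,          # Beta(alpha, alpha) is used to sample the mixing coefficient
        "num_classes": 1000,   # needed when targets arrive as integer class labels
    },
}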
50 changes: 50 additions & 0 deletions test/dataset_transforms_mixup_test.py
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch
from classy_vision.dataset.transforms.mixup import MixupTransform


class DatasetTransformsMixupTest(unittest.TestCase):
    def test_mixup_transform_single_label(self):
        alpha = 2.0
        num_classes = 3
        mixup_transform = MixupTransform(alpha, num_classes)
        sample = {
            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
            "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
        }
        sample_mixup = mixup_transform(sample)
        self.assertTrue(sample["input"].shape == sample_mixup["input"].shape)
        self.assertTrue(sample_mixup["target"].shape[0] == 4)
        self.assertTrue(sample_mixup["target"].shape[1] == 3)

    def test_mixup_transform_single_label_missing_num_classes(self):
        alpha = 2.0
        mixup_transform = MixupTransform(alpha, None)
        sample = {
            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
            "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
        }
        with self.assertRaises(Exception):
            mixup_transform(sample)

    def test_mixup_transform_multi_label(self):
        alpha = 2.0
        mixup_transform = MixupTransform(alpha, None)
        sample = {
            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
            "target": torch.as_tensor(
                [[1, 0, 0, 0], [0, 1, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1]],
                dtype=torch.int32,
            ),
        }
        sample_mixup = mixup_transform(sample)
        self.assertTrue(sample["input"].shape == sample_mixup["input"].shape)
        self.assertTrue(sample["target"].shape == sample_mixup["target"].shape)
7 changes: 7 additions & 0 deletions test/losses_soft_target_cross_entropy_loss_test.py
@@ -47,6 +47,13 @@ def test_soft_target_cross_entropy(self):
        targets = torch.tensor([[-1, 0, 0, 0, 1]])
        self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)

    def test_soft_target_cross_entropy_integer_label(self):
        config = self._get_config()
        crit = SoftTargetCrossEntropyLoss.from_config(config)
        outputs = self._get_outputs()
        targets = torch.tensor([4])
        self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)

    def test_unnormalized_soft_target_cross_entropy(self):
        config = {
            "name": "soft_target_cross_entropy",