use the randomizable API
juampatronics committed Nov 7, 2023
1 parent d221cf1 commit aaa640f
Showing 2 changed files with 46 additions and 41 deletions.
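For readers landing here from the commit message: "the randomizable API" refers to MONAI's Randomizable mix-in, which gives a transform a seedable RNG (self.R) and a randomize() hook. Below is a compressed sketch of that contract for orientation only; it is a simplification, not the actual class in monai/transforms/transform.py.

    import numpy as np

    class Randomizable:
        """Simplified sketch of MONAI's Randomizable contract (orientation only)."""

        R: np.random.RandomState = np.random.RandomState()

        def set_random_state(self, seed=None, state=None):
            # Swap in a reproducible RNG; MONAI returns self for chaining.
            self.R = state if state is not None else np.random.RandomState(seed)
            return self

        def randomize(self, data=None) -> None:
            # Subclasses draw all random parameters from self.R here,
            # so seeding self.R makes the whole transform reproducible.
            raise NotImplementedError

This commit moves the mixup-family transforms from an ad-hoc sample_params() helper onto that contract.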
76 changes: 40 additions & 36 deletions monai/transforms/regularization/array.py
@@ -10,34 +10,36 @@
# limitations under the License.

from abc import abstractmethod
-from typing import Tuple
import torch
-from monai.transforms import Transform
+from monai.transforms import Transform, Randomizable
from math import sqrt, ceil

__all__ = ["MixUp", "CutMix", "CutOut"]


-class Mixer(Transform):
+class Mixer(Transform, Randomizable):
    def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
        super().__init__()
        if alpha <= 0:
            raise ValueError(f"Expected positive number, but got {alpha = }")
-        self._sampler = torch.distributions.beta.Beta(alpha, alpha)
+        self.alpha = alpha
        self.batch_size = batch_size

-    def sample_params(self):
+    @abstractmethod
+    def apply(self, data: torch.Tensor):
+        raise NotImplementedError()
+
+    def randomize(self, data=None) -> None:
        """
        Sometimes you may need to apply the same transform to different tensors.
-        The idea is to get a sample and then apply it with apply_mixup() as often
-        as needed.
+        The idea is to get a sample and then apply it with apply() as often
+        as needed. You need to call this method every time you apply the transform to a new
+        batch.
        """
-        return self._sampler.sample((self.batch_size,)), torch.randperm(self.batch_size)
-
-    @classmethod
-    @abstractmethod
-    def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
-        raise NotImplementedError()
+        self._params = (
+            torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
+            self.R.permutation(self.batch_size),
+        )


class MixUp(Mixer):
@@ -49,9 +51,8 @@ class MixUp(Mixer):
    def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
        super().__init__(batch_size, alpha)

-    @classmethod
-    def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
-        weight, perm = params
+    def apply(self, data: torch.Tensor):
+        weight, perm = self._params
        nsamples, *dims = data.shape
        if len(weight) != nsamples:
            raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
@@ -63,26 +64,24 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
        return mixweight * data + (1 - mixweight) * data[perm, ...]

    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+        self.randomize()
        if labels is None:
-            return self.apply(self.sample_params(), data)
-
-        params = self.sample_params()
-        return self.apply(params, data), self.apply(params, labels)
+            return self.apply(data)
+        return self.apply(data), self.apply(labels)


class CutMix(Mixer):
    """CutMix augmentation as described in:
-        Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
+        Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
        CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
        ICCV 2019
    """

    def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
        super().__init__(batch_size, alpha)

-    @classmethod
-    def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
-        weights, perm = params
+    def apply(self, data: torch.Tensor):
+        weights, perm = self._params
        nsamples, _, *dims = data.shape
        if len(weights) != nsamples:
            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
@@ -96,26 +95,30 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):

        return mask * data + (1 - mask) * data[perm, ...]

-    @classmethod
-    def apply_on_labels(cls, params: Tuple[torch.Tensor, torch.Tensor], labels: torch.Tensor):
-        return MixUp.apply(params, labels)
+    def apply_on_labels(self, labels: torch.Tensor):
+        weights, perm = self._params
+        nsamples, *dims = labels.shape
+        if len(weights) != nsamples:
+            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
+
+        mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
+        return mixweight * labels + (1 - mixweight) * labels[perm, ...]

    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
-        params = self.sample_params()
-        augmented = self.apply(params, data)
-        return (augmented, MixUp.apply(params, labels)) if labels is not None else augmented
+        self.randomize()
+        augmented = self.apply(data)
+        return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented


class CutOut(Mixer):
    """Cutout as described in the paper:
-        Terrance DeVries, Graham W. Taylor
-        Improved Regularization of Convolutional Neural Networks with Cutout
+        Terrance DeVries, Graham W. Taylor.
+        Improved Regularization of Convolutional Neural Networks with Cutout,
+        arXiv:1708.04552
    """

-    @classmethod
-    def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
-        weights, _ = params
+    def apply(self, data: torch.Tensor):
+        weights, _ = self._params
        nsamples, _, *dims = data.shape
        if len(weights) != nsamples:
            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
@@ -130,4 +133,5 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
        return mask * data

    def __call__(self, data: torch.Tensor):
-        return self.apply(self.sample_params(), data)
+        self.randomize()
+        return self.apply(data)
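To illustrate the reworked flow, here is a minimal usage sketch against the API as it appears in this diff (tensor shapes and the seed are illustrative; set_random_state comes from the Randomizable base class):

    import torch
    from monai.transforms.regularization.array import MixUp

    mixup = MixUp(batch_size=4, alpha=1.0)
    mixup.set_random_state(seed=0)  # inherited from Randomizable

    images = torch.rand(4, 1, 32, 32)
    labels = torch.rand(4, 3)

    # __call__ now draws fresh parameters via randomize() and mixes both
    # tensors with the same weights and permutation.
    mixed_images, mixed_labels = mixup(images, labels)

    # To reuse one draw across several tensors, randomize() once and then
    # call apply() as often as needed.
    mixup.randomize()
    same_mix_images = mixup.apply(images)
    same_mix_labels = mixup.apply(labels)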
11 changes: 6 additions & 5 deletions monai/transforms/regularization/dictionary.py
@@ -37,10 +37,10 @@ def __init__(
        self.mixup = MixUp(batch_size, alpha)

    def __call__(self, data):
+        self.mixup.randomize()
        result = dict(data)
-        params = self.mixup.sample_params()
        for k in self.keys:
-            result[k] = self.mixup.apply(params, data[k])
+            result[k] = self.mixup.apply(data[k])
        return result


@@ -71,12 +71,12 @@ def __init__(
        self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []

    def __call__(self, data):
+        self.mixer.randomize()
        result = dict(data)
-        params = self.mixer.sample_params()
        for k in self.keys:
-            result[k] = self.mixer.apply(params, data[k])
+            result[k] = self.mixer.apply(data[k])
        for k in self.label_keys:
-            result[k] = self.mixer.apply_on_labels(params, data[k])
+            result[k] = self.mixer.apply_on_labels(data[k])
        return result


@@ -98,6 +98,7 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo

    def __call__(self, data):
        result = dict(data)
+        self.cutout.randomize()
        for k in self.keys:
            result[k] = self.cutout(data[k])
        return result
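The dictionary-based wrappers therefore share one parameter draw across all listed keys per call. A sketch of that behavior, assuming the MixUp wrapper is named MixUpd with the usual keys/batch_size arguments (the hunk context above does not show the full signature, and the key names "image" and "label" are illustrative):

    import torch
    from monai.transforms.regularization.dictionary import MixUpd

    batch = {"image": torch.rand(4, 1, 32, 32), "label": torch.rand(4, 3)}
    mixupd = MixUpd(keys=["image", "label"], batch_size=4)

    # One randomize() per call, so "image" and "label" are mixed with the
    # same weights and permutation.
    out = mixupd(batch)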
