diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index 109fd414f49..d244da24a95 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -10,34 +10,36 @@
 # limitations under the License.
 
 from abc import abstractmethod
-from typing import Tuple
 
 import torch
 
-from monai.transforms import Transform
+from monai.transforms import Transform, Randomizable
 
 from math import sqrt, ceil
 
 __all__ = ["MixUp", "CutMix", "CutOut"]
 
 
-class Mixer(Transform):
+class Mixer(Transform, Randomizable):
     def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
         super().__init__()
         if alpha <= 0:
             raise ValueError(f"Expected positive number, but got {alpha = }")
-        self._sampler = torch.distributions.beta.Beta(alpha, alpha)
+        self.alpha = alpha
         self.batch_size = batch_size
 
-    def sample_params(self):
+    @abstractmethod
+    def apply(self, data: torch.Tensor):
+        raise NotImplementedError()
+
+    def randomize(self, data=None) -> None:
         """
         Sometimes you may need to apply the same transform to different tensors.
-        The idea is to get a sample and then apply it with apply_mixup() as often
-        as needed.
+        The idea is to get a sample and then apply it with apply() as often
+        as needed. You need to call this method every time you apply the transform to a new
+        batch.
         """
-        return self._sampler.sample((self.batch_size,)), torch.randperm(self.batch_size)
-
-    @classmethod
-    @abstractmethod
-    def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
-        raise NotImplementedError()
+        self._params = (
+            torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
+            self.R.permutation(self.batch_size),
+        )
 
 
 class MixUp(Mixer):
@@ -49,9 +51,8 @@ class MixUp(Mixer):
     def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
         super().__init__(batch_size, alpha)
 
-    @classmethod
-    def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
-        weight, perm = params
+    def apply(self, data: torch.Tensor):
+        weight, perm = self._params
         nsamples, *dims = data.shape
         if len(weight) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
@@ -63,16 +64,15 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
         return mixweight * data + (1 - mixweight) * data[perm, ...]
 
     def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+        self.randomize()
         if labels is None:
-            return self.apply(self.sample_params(), data)
-
-        params = self.sample_params()
-        return self.apply(params, data), self.apply(params, labels)
+            return self.apply(data)
+        return self.apply(data), self.apply(labels)
 
 
 class CutMix(Mixer):
     """CutMix augmentation as described in:
-    Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
+    Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
     CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, ICCV 2019
     """
 
@@ -80,9 +80,8 @@ class CutMix(Mixer):
     def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
         super().__init__(batch_size, alpha)
 
-    @classmethod
-    def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
-        weights, perm = params
+    def apply(self, data: torch.Tensor):
+        weights, perm = self._params
         nsamples, _, *dims = data.shape
         if len(weights) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
@@ -96,26 +95,30 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
         return mask * data + (1 - mask) * data[perm, ...]
 
-    @classmethod
-    def apply_on_labels(cls, params: Tuple[torch.Tensor, torch.Tensor], labels: torch.Tensor):
-        return MixUp.apply(params, labels)
+    def apply_on_labels(self, labels: torch.Tensor):
+        weights, perm = self._params
+        nsamples, *dims = labels.shape
+        if len(weights) != nsamples:
+            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
+
+        mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
+        return mixweight * labels + (1 - mixweight) * labels[perm, ...]
 
     def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
-        params = self.sample_params()
-        augmented = self.apply(params, data)
-        return (augmented, MixUp.apply(params, labels)) if labels is not None else augmented
+        self.randomize()
+        augmented = self.apply(data)
+        return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
 
 
 class CutOut(Mixer):
     """Cutout as described in the paper:
-    Terrance DeVries, Graham W. Taylor
-    Improved Regularization of Convolutional Neural Networks with Cutout
+    Terrance DeVries, Graham W. Taylor.
+    Improved Regularization of Convolutional Neural Networks with Cutout, arXiv:1708.04552
     """
 
-    @classmethod
-    def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
-        weights, _ = params
+    def apply(self, data: torch.Tensor):
+        weights, _ = self._params
         nsamples, _, *dims = data.shape
         if len(weights) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
@@ -130,4 +133,5 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):
         return mask * data
 
     def __call__(self, data: torch.Tensor):
-        return self.apply(self.sample_params(), data)
+        self.randomize()
+        return self.apply(data)
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index fab55be2b3b..99414653b27 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -37,10 +37,10 @@ def __init__(
         self.mixup = MixUp(batch_size, alpha)
 
     def __call__(self, data):
+        self.mixup.randomize()
         result = dict(data)
-        params = self.mixup.sample_params()
         for k in self.keys:
-            result[k] = self.mixup.apply(params, data[k])
+            result[k] = self.mixup.apply(data[k])
         return result
 
 
@@ -71,12 +71,12 @@ def __init__(
         self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
 
     def __call__(self, data):
+        self.mixer.randomize()
         result = dict(data)
-        params = self.mixer.sample_params()
         for k in self.keys:
-            result[k] = self.mixer.apply(params, data[k])
+            result[k] = self.mixer.apply(data[k])
         for k in self.label_keys:
-            result[k] = self.mixer.apply_on_labels(params, data[k])
+            result[k] = self.mixer.apply_on_labels(data[k])
         return result
 
 
@@ -98,6 +98,7 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo
 
     def __call__(self, data):
         result = dict(data)
+        self.cutout.randomize()
         for k in self.keys:
             result[k] = self.cutout(data[k])
         return result
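
A minimal usage sketch of the refactored API, for review purposes only: the batch shapes, seed, and variable names are illustrative, and it assumes `MixUp`/`CutMix` remain importable from `monai.transforms` as the module's `__all__` suggests. `set_random_state()` is inherited from the `Randomizable` base class.

```python
import torch

from monai.transforms import CutMix, MixUp

images = torch.rand(4, 1, 32, 32)  # (batch, channel, H, W)
labels = torch.rand(4, 1, 32, 32)  # image-like labels, e.g. segmentation masks

mixup = MixUp(batch_size=4, alpha=1.0)
mixup.set_random_state(seed=0)               # Randomizable API: reproducible draws
mixed, mixed_labels = mixup(images, labels)  # __call__ draws fresh params via randomize()

# To reuse one draw across several tensors: randomize() once, then apply() repeatedly.
mixup.randomize()
out_a = mixup.apply(images)
out_b = mixup.apply(images)  # same weights and permutation as out_a

cutmix = CutMix(batch_size=4, alpha=1.0)
patched, soft_labels = cutmix(images, labels)  # images are cut-and-pasted, labels are mixed
```

Drawing the Beta weights and the permutation from `self.R` inside `randomize()`, rather than from a private `torch.distributions` sampler, is what makes these transforms honor `set_random_state()` like the rest of MONAI's `Randomizable` transforms.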