2872 implementation of mixup, cutmix and cutout #7198

Merged: 90 commits, Mar 26, 2024

Changes from all commits

Commits (90)
79719fa
mixup, cutmix and cutout
juampatronics Nov 3, 2023
1528cdf
added rst file
juampatronics Nov 3, 2023
ff45686
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2023
ea89145
added missing module
juampatronics Nov 3, 2023
d35c1b3
refactor code as submodule of transforms module
juampatronics Nov 6, 2023
1c54f1c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2023
9b05593
use the randomizable API
juampatronics Nov 7, 2023
83b2c98
used types compatible with python <3.10
juampatronics Nov 7, 2023
efd5e99
used types compatible with python <3.10
juampatronics Nov 7, 2023
f6af5cf
fixed isort errors
juampatronics Nov 7, 2023
8e7ec73
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2023
5eb8c7f
auto updates (#7203)
monai-bot Nov 6, 2023
674812f
changes from command ./runtests.sh --autofix
juampatronics Nov 14, 2023
c43f8ca
fix useless error msg in nnunetv2runner (#7217)
elitap Nov 15, 2023
cad947e
Fixup mypy 1.7.0 errors (#7231)
Shadow-Devil Nov 16, 2023
81097ab
add Yun Liu to user list to trigger blossom-ci [skip ci] (#7239)
YanxuanLiu Nov 17, 2023
da27e2e
Replace single quotation marks with double quotation marks to install…
ytl0623 Nov 17, 2023
f665179
Update bug_report.md (#7213)
dzenanz Nov 17, 2023
36beba1
Add cache option in `GridPatchDataset` (#7180)
KumoLiu Nov 17, 2023
af45bb2
:memo: [array] Add examples for EnsureType and CastToType (#7245)
ishandutta0098 Nov 19, 2023
fede45b
:hammer: [dataset] Handle corrupted cached file in PersistentDataset …
ishandutta0098 Nov 20, 2023
4bc7504
auto updates (#7247)
monai-bot Nov 20, 2023
6ec781a
add class label option to write metric report to improve readability …
elitap Nov 22, 2023
a11d8d3
Fix B026 unrecommanded star-arg unpacking after a keyword argument (#…
KumoLiu Nov 27, 2023
149a73a
Quote $PY_EXE variable to deal with Python path that contain spaces i…
ytl0623 Nov 30, 2023
308e9e2
add SoftclDiceLoss and SoftDiceclDiceLoss loss function in documentat…
ytl0623 Nov 30, 2023
be4a873
Skip Old Pytorch Versions for `SwinUNETR` (#7266)
KumoLiu Nov 30, 2023
a6e1b71
Bump conda-incubator/setup-miniconda from 2 to 3 (#7274)
dependabot[bot] Dec 1, 2023
dae5478
wholeBody_ct_segmentation failed to be download (#7280)
KumoLiu Dec 4, 2023
9802109
update the Python version requirements for transformers (#7275)
KumoLiu Dec 5, 2023
5fb9f2b
7263 add diffusion loss (#7272)
kvttt Dec 5, 2023
6a45df2
Fix swinunetrv2 2D bug (#7302)
heyufan1995 Dec 8, 2023
fc82d50
Fix `RuntimeError` in `DataAnalyzer` (#7310)
KumoLiu Dec 12, 2023
210b23a
Support specified filenames in `Saveimage` (#7318)
KumoLiu Dec 14, 2023
2fb60c1
Fix typo (#7321)
KumoLiu Dec 15, 2023
275d51f
fix optimizer pararmeter issue (#7322)
binliunls Dec 15, 2023
469db7a
Fix `lazy` ignored in `SpatialPadd` (#7316)
KumoLiu Dec 18, 2023
3f3e03c
Update openslide-python version (#7344)
KumoLiu Dec 28, 2023
71d838f
Upgrade the version of `transformers` (#7343)
KumoLiu Dec 29, 2023
80ed15f
Bump github/codeql-action from 2 to 3 (#7354)
dependabot[bot] Jan 2, 2024
c210768
Bump actions/upload-artifact from 3 to 4 (#7350)
dependabot[bot] Jan 2, 2024
285dcfc
Bump actions/setup-python from 4 to 5 (#7351)
dependabot[bot] Jan 2, 2024
a6c83d0
Bump actions/download-artifact from 3 to 4 (#7352)
dependabot[bot] Jan 2, 2024
ac56b50
Bump peter-evans/slash-command-dispatch from 3.0.1 to 3.0.2 (#7353)
dependabot[bot] Jan 3, 2024
2fa5bf1
Give more useful exception when batch is considered during matrix mul…
KumoLiu Jan 8, 2024
23ab35c
Fix incorrectly size compute in auto3dseg analyzer (#7374)
KumoLiu Jan 9, 2024
7cc7c9b
7380 mention demo in bending energy and diffusion docstrings (#7381)
kvttt Jan 10, 2024
2e382b4
Pin gdown version to v4.6.3 (#7384)
KumoLiu Jan 12, 2024
7e7d278
Fix Premerge (#7397)
KumoLiu Jan 18, 2024
ad7e3fa
Track applied operations in image filter (#7395)
vlaminckaxel Jan 18, 2024
aef4b57
Add `compile` support in `SupervisedTrainer` and `SupervisedEvaluator…
KumoLiu Jan 19, 2024
d8236ea
Fix CUDA_VISIBLE_DEVICES setting ignored (#7408)
KumoLiu Jan 22, 2024
433a3aa
Fix Incorrect updated affine in `NrrdReader` and update docstring in …
KumoLiu Jan 25, 2024
1b4091c
Ignore E704 after update black (#7422)
KumoLiu Jan 30, 2024
f8bfc7c
update `rm -rf /opt/hostedtoolcache` avoid change the python version …
KumoLiu Feb 1, 2024
69e7e05
Bump peter-evans/slash-command-dispatch from 3.0.2 to 4.0.0 (#7428)
dependabot[bot] Feb 1, 2024
bcea0e8
Bump peter-evans/create-or-update-comment from 3 to 4 (#7429)
dependabot[bot] Feb 2, 2024
748dce5
Bump actions/cache from 3 to 4 (#7430)
dependabot[bot] Feb 2, 2024
f9b4fc2
Bump codecov/codecov-action from 3 to 4 (#7431)
dependabot[bot] Feb 2, 2024
0dc013d
Update tensorboard version to fix deadlock (#7435)
KumoLiu Feb 2, 2024
c3ca41c
auto updates (#7439)
monai-bot Feb 5, 2024
449c2fb
Instantiation mode `"partial"` to `"callable"`. Return the `_target_`…
ibro45 Feb 6, 2024
eb8c8aa
Add support for mlflow experiment name in auto3dseg (#7442)
drbeh Feb 6, 2024
5ab247e
Update gdown version (#7448)
KumoLiu Feb 7, 2024
4b4c4f9
Skip "test_gaussian_filter" as a workaround for blossom killed (#7474)
KumoLiu Feb 20, 2024
d1de764
auto updates (#7463)
monai-bot Feb 20, 2024
42a4e2c
Skip "test_resize" as a workaround for blossom killed (#7484)
KumoLiu Feb 21, 2024
1394916
Fix Python 3.12 import AttributeError (#7482)
KumoLiu Feb 21, 2024
ff19822
Update test_nnunetv2runner (#7483)
KumoLiu Feb 22, 2024
5bbaab9
Fix github resource issue when build latest docker (#7450)
KumoLiu Feb 23, 2024
473593e
Use int16 instead of int8 in `LabelStats` (#7489)
KumoLiu Feb 23, 2024
01a8a24
auto updates (#7495)
monai-bot Feb 26, 2024
b0c96d8
Add sample_std parameter to RandGaussianNoise. (#7492)
bakert1 Feb 26, 2024
771af49
Add __repr__ and __str__ to Metrics baseclass (#7487)
MathijsdeBoer Feb 28, 2024
ee8bd4f
Bump al-cheb/configure-pagefile-action from 1.3 to 1.4 (#7510)
dependabot[bot] Mar 1, 2024
55be1d0
Add arm support (#7500)
KumoLiu Mar 3, 2024
6ad169a
Fix error in "test_bundle_trt_export" (#7524)
KumoLiu Mar 10, 2024
9f57cb2
Fix typo in the PerceptualNetworkType Enum (#7548)
SomeUserName1 Mar 15, 2024
5465ae3
Update to use `log_sigmoid` in `FocalLoss` (#7534)
KumoLiu Mar 18, 2024
e4a8346
Update integration_segmentation_3d result for PyTorch2403 (#7551)
KumoLiu Mar 22, 2024
a85d6a9
Add Barlow Twins loss for representation learning (#7530)
Lucas-rbnt Mar 22, 2024
1916a41
Stein's Unbiased Risk Estimator (SURE) loss and Conjugate Gradient (#…
cxlcl Mar 22, 2024
7d48f9e
fixed code format checks
juampatronics Mar 24, 2024
292e84d
added feedback suggestions
juampatronics Nov 8, 2023
2606758
auto updates (#7577)
monai-bot Mar 25, 2024
c9a6521
flake8 warnings
juampatronics Mar 25, 2024
af42d65
DCO Remediation Commit for Juan Pablo de la Cruz Gutiérrez <juampatro…
juampatronics Mar 25, 2024
859852e
finally got sphinx format right
juampatronics Mar 25, 2024
070d963
Merge branch 'dev' into 2872-mixup
juampatronics Mar 25, 2024
a59f41f
Merge branch 'dev' into 2872-mixup
KumoLiu Mar 26, 2024
42 changes: 42 additions & 0 deletions docs/source/transforms.rst
@@ -661,6 +661,27 @@ Post-processing
:members:
:special-members: __call__

Regularization
^^^^^^^^^^^^^^

`CutMix`
""""""""
.. autoclass:: CutMix
:members:
:special-members: __call__

`CutOut`
""""""""
.. autoclass:: CutOut
:members:
:special-members: __call__

`MixUp`
"""""""
.. autoclass:: MixUp
:members:
:special-members: __call__

Signal
^^^^^^^

@@ -1707,6 +1728,27 @@ Post-processing (Dict)
:members:
:special-members: __call__

Regularization (Dict)
^^^^^^^^^^^^^^^^^^^^^

`CutMixd`
"""""""""
.. autoclass:: CutMixd
:members:
:special-members: __call__

`CutOutd`
"""""""""
.. autoclass:: CutOutd
:members:
:special-members: __call__

`MixUpd`
""""""""
.. autoclass:: MixUpd
:members:
:special-members: __call__

Signal (Dict)
^^^^^^^^^^^^^

10 changes: 10 additions & 0 deletions docs/source/transforms_idx.rst
@@ -74,6 +74,16 @@ Post-processing
post.array
post.dictionary

Regularization
^^^^^^^^^^^^^^

.. autosummary::
:toctree: _gen
:nosignatures:

regularization.array
regularization.dictionary

Signal
^^^^^^

12 changes: 12 additions & 0 deletions monai/transforms/__init__.py
@@ -336,6 +336,18 @@
VoteEnsembled,
VoteEnsembleDict,
)
from .regularization.array import CutMix, CutOut, MixUp
from .regularization.dictionary import (
CutMixd,
CutMixD,
CutMixDict,
CutOutd,
CutOutD,
CutOutDict,
MixUpd,
MixUpD,
MixUpDict,
)
from .signal.array import (
SignalContinuousWavelet,
SignalFillEmpty,
10 changes: 10 additions & 0 deletions monai/transforms/regularization/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
173 changes: 173 additions & 0 deletions monai/transforms/regularization/array.py
@@ -0,0 +1,173 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import abstractmethod
from math import ceil, sqrt

import torch

from ..transform import RandomizableTransform

__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"]


class Mixer(RandomizableTransform):
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
"""
Mixer is a base class providing the basic logic for the mixup-class of
augmentations. In all cases, we need to sample the mixing weights for each
sample (lambda in the notation used in the papers). Also, pairs of samples
being mixed are picked by randomly shuffling the batch samples.

Args:
batch_size (int): number of samples per batch. That is, samples are expected to
be of size batch_size x channels [x depth] x height x width.
alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha)
distribution. Defaults to 1.0, the uniform distribution.
"""
super().__init__()
if alpha <= 0:
raise ValueError(f"Expected positive number, but got {alpha = }")
self.alpha = alpha
self.batch_size = batch_size

@abstractmethod
def apply(self, data: torch.Tensor):
raise NotImplementedError()

def randomize(self, data=None) -> None:
"""
Sometimes you may need to apply the same transform to different tensors.
The idea is to sample the parameters once and then use apply() as often
as needed. You need to call this method every time you apply the transform to a new
batch.
"""
self._params = (
torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
self.R.permutation(self.batch_size),
)


class MixUp(Mixer):
"""MixUp as described in:
Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
mixup: Beyond Empirical Risk Minimization, ICLR 2018

Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
documentation for details on the constructor parameters.
"""

def apply(self, data: torch.Tensor):
weight, perm = self._params
nsamples, *dims = data.shape
if len(weight) != nsamples:
raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")

if len(dims) not in [3, 4]:
raise ValueError("Unexpected number of dimensions")

mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
return mixweight * data + (1 - mixweight) * data[perm, ...]

def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
self.randomize()
if labels is None:
return self.apply(data)
return self.apply(data), self.apply(labels)


class CutMix(Mixer):
"""CutMix augmentation as described in:
Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
ICCV 2019

Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
documentation for details on the constructor parameters. Here, alpha not only determines
the mixing weight but also the size of the random rectangles used for mixing.
Please refer to the paper for details.

The most common use case is something close to:

.. code-block:: python

cm = CutMix(batch_size=8, alpha=0.5)
for batch in loader:
images, labels = batch
augimg, auglabels = cm(images, labels)
output = model(augimg)
loss = loss_function(output, auglabels)
...

"""

def apply(self, data: torch.Tensor):
weights, perm = self._params
nsamples, _, *dims = data.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

mask = torch.ones_like(data)
for s, weight in enumerate(weights):
coords = [torch.randint(0, d, size=(1,)) for d in dims]
lengths = [d * sqrt(1 - weight) for d in dims]
idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
mask[s][idx] = 0

return mask * data + (1 - mask) * data[perm, ...]

def apply_on_labels(self, labels: torch.Tensor):
weights, perm = self._params
nsamples, *dims = labels.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
return mixweight * labels + (1 - mixweight) * labels[perm, ...]

def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
self.randomize()
augmented = self.apply(data)
return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented


class CutOut(Mixer):
"""Cutout as described in the paper:
Terrance DeVries, Graham W. Taylor.
Improved Regularization of Convolutional Neural Networks with Cutout,
arXiv:1708.04552

Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
documentation for details on the constructor parameters. Here, alpha not only determines
the mixing weight but also the size of the random rectangles being cut out.
Please refer to the paper for details.
"""

def apply(self, data: torch.Tensor):
weights, _ = self._params
nsamples, _, *dims = data.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

mask = torch.ones_like(data)
for s, weight in enumerate(weights):
coords = [torch.randint(0, d, size=(1,)) for d in dims]
lengths = [d * sqrt(1 - weight) for d in dims]
idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
mask[s][idx] = 0

return mask * data

def __call__(self, data: torch.Tensor):
self.randomize()
return self.apply(data)
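
As a usage reference, here is a minimal sketch of the array transforms above, assuming a MONAI build that includes this PR. The batch shapes and tensors are illustrative assumptions, not part of the diff:

.. code-block:: python

    import torch

    from monai.transforms import CutOut, MixUp

    images = torch.rand(8, 1, 32, 32)  # batch_size x channels x height x width
    labels = torch.rand(8, 1, 32, 32)

    # passing images and labels together mixes both with the same
    # lambda weights and batch permutation
    mixup = MixUp(batch_size=8, alpha=1.0)
    aug_images, aug_labels = mixup(images, labels)

    # equivalently, sample the parameters once and reuse them explicitly
    mixup.randomize()
    aug_images = mixup.apply(images)
    aug_labels = mixup.apply(labels)

    # CutOut only zeroes random rectangles, so it takes no labels
    cutout = CutOut(batch_size=8, alpha=1.0)
    aug_images = cutout(images)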
97 changes: 97 additions & 0 deletions monai/transforms/regularization/dictionary.py
@@ -0,0 +1,97 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from monai.config import KeysCollection
from monai.utils.misc import ensure_tuple

from ..transform import MapTransform
from .array import CutMix, CutOut, MixUp

__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]


class MixUpd(MapTransform):
"""
Dictionary-based version of :py:class:`monai.transforms.MixUp`.

Notice that the mixup transformation will be the same for all entries
for consistency, i.e. the same augmentation is applied to images and labels.
"""

def __init__(
self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
) -> None:
super().__init__(keys, allow_missing_keys)
self.mixup = MixUp(batch_size, alpha)

def __call__(self, data):
self.mixup.randomize()
result = dict(data)
for k in self.keys:
result[k] = self.mixup.apply(data[k])
return result


class CutMixd(MapTransform):
"""
Dictionary-based version of :py:class:`monai.transforms.CutMix`.

Notice that the mixture weights will be the same for all entries
for consistency, i.e. images and labels must be aggregated with the same weights,
but the random crop regions differ between entries.
"""

def __init__(
self,
keys: KeysCollection,
batch_size: int,
label_keys: KeysCollection | None = None,
alpha: float = 1.0,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.mixer = CutMix(batch_size, alpha)
self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []

def __call__(self, data):
self.mixer.randomize()
result = dict(data)
for k in self.keys:
result[k] = self.mixer.apply(data[k])
for k in self.label_keys:
result[k] = self.mixer.apply_on_labels(data[k])
return result


class CutOutd(MapTransform):
"""
Dictionary-based version of :py:class:`monai.transforms.CutOut`.

Notice that the cutout is different for every entry in the dictionary.
"""

def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:
super().__init__(keys, allow_missing_keys)
self.cutout = CutOut(batch_size)

def __call__(self, data):
result = dict(data)
self.cutout.randomize()
for k in self.keys:
result[k] = self.cutout(data[k])
return result


MixUpD = MixUpDict = MixUpd
CutMixD = CutMixDict = CutMixd
CutOutD = CutOutDict = CutOutd
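
Similarly, a sketch of the dictionary-based variants above, using hypothetical "image"/"label" keys and shapes chosen for illustration:

.. code-block:: python

    import torch

    from monai.transforms import CutMixd, MixUpd

    batch = {
        "image": torch.rand(4, 1, 32, 32),
        "label": torch.rand(4, 1, 32, 32),
    }

    # the same mixup parameters are applied to every listed key
    mixup = MixUpd(keys=["image", "label"], batch_size=4, alpha=1.0)
    out = mixup(batch)

    # CutMix pastes random rectangles within the image keys and aggregates
    # the label keys with the matching per-sample weights
    cutmix = CutMixd(keys="image", batch_size=4, label_keys="label", alpha=1.0)
    out = cutmix(batch)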