This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Merge pull request #100 from PhoenixDL/random_crop_seed
Add BaseTransformSeeded and warnings to other transforms
mibaumgartner authored May 18, 2020
2 parents 71cb42f + c226a23 commit fd3b79f
Showing 7 changed files with 121 additions and 33 deletions.
8 changes: 8 additions & 0 deletions docs/source/transforms.rst
@@ -31,6 +31,14 @@ Transformation Base Classes
    :undoc-members:
    :show-inheritance:

:hidden:`BaseTransformSeeded`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: BaseTransformSeeded
    :members:
    :undoc-members:
    :show-inheritance:

:hidden:`PerSampleTransform`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

53 changes: 49 additions & 4 deletions rising/transforms/abstract.py
@@ -4,7 +4,7 @@
from rising.random import AbstractParameter, DiscreteParameter

__all__ = ["AbstractTransform", "BaseTransform", "PerSampleTransform",
           "PerChannelTransform"]
           "PerChannelTransform", "BaseTransformSeeded"]

augment_callable = Callable[[torch.Tensor], Any]
augment_axis_callable = Callable[[torch.Tensor, Union[float, Sequence]], Any]
@@ -119,7 +119,13 @@ def forward(self, **data) -> dict:


class BaseTransform(AbstractTransform):
    """Transform to apply a functional interface to given keys"""
    """
    Transform to apply a functional interface to given keys

    .. warning:: This transform should not be used
        with functions which have randomness built in because it will
        result in different augmentations per key.
    """

    def __init__(self, augment_fn: augment_callable, *args,
                 keys: Sequence = ('data',), grad: bool = False,
@@ -165,11 +171,44 @@ def forward(self, **data) -> dict:
        return data


class BaseTransformSeeded(BaseTransform):
    """
    Transform to apply a functional interface to given keys and use the same
    PyTorch seed for every key.
    """

    def forward(self, **data) -> dict:
        """
        Apply transformation and use same seed for every key

        Args:
            data: dict with tensors

        Returns:
            dict: dict with augmented data
        """
        kwargs = {}
        for k in self.property_names:
            kwargs[k] = getattr(self, k)

        kwargs.update(self.kwargs)

        seed = torch.random.get_rng_state()
        for _key in self.keys:
            torch.random.set_rng_state(seed)
            data[_key] = self.augment_fn(data[_key], *self.args, **kwargs)
        return data
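
The forward above captures the global torch RNG state once and restores it before every key, so each key sees exactly the same random draws. The pattern can be reproduced in plain PyTorch (a minimal sketch, independent of rising):

import torch

state = torch.random.get_rng_state()   # remember the current global RNG state
a = torch.randn(3)                      # first key consumes the RNG
torch.random.set_rng_state(state)       # rewind before the next key
b = torch.randn(3)
assert torch.equal(a, b)                # identical draws for every key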


class PerSampleTransform(BaseTransform):
    """
    Apply transformation to each sample in batch individually
    :attr:`augment_fn` must be callable with option :attr:`out`
    where results are saved in
    where results are saved in.

    .. warning:: This transform should not be used
        with functions which have randomness built in because it will
        result in different augmentations per sample and key.
    """

    def forward(self, **data) -> dict:
@@ -194,7 +233,13 @@ def forward(self, **data) -> dict:


class PerChannelTransform(BaseTransform):
    """Apply transformation per channel (but still to whole batch)"""
    """
    Apply transformation per channel (but still to whole batch)

    .. warning:: This transform should not be used
        with functions which have randomness built in because it will
        result in different augmentations per channel and key.
    """

    def __init__(self, augment_fn: augment_callable, per_channel: bool = False,
                 keys: Sequence = ('data',), grad: bool = False,
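
The warnings added in this file come down to one point: a random augment_fn is called once per key, so paired keys drift apart unless the seeded base class is used. A sketch of the difference, assuming the constructors accept the arguments exactly as shown in this diff and using a hypothetical jitter function:

import torch
from rising.transforms.abstract import BaseTransform, BaseTransformSeeded

def jitter(x: torch.Tensor) -> torch.Tensor:
    # hypothetical random augmentation: additive Gaussian noise
    return x + torch.randn_like(x)

batch = {"data": torch.zeros(1, 1, 4, 4), "seg": torch.zeros(1, 1, 4, 4)}

plain = BaseTransform(augment_fn=jitter, keys=("data", "seg"))(**batch)
print(torch.equal(plain["data"], plain["seg"]))    # expected: False, noise differs per key

seeded = BaseTransformSeeded(augment_fn=jitter, keys=("data", "seg"))(**batch)
print(torch.equal(seeded["data"], seeded["seg"]))  # expected: True, same seed restored per key
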
7 changes: 4 additions & 3 deletions rising/transforms/crop.py
@@ -1,5 +1,7 @@
import torch

from typing import Sequence, Union
from rising.transforms.abstract import BaseTransform
from rising.transforms.abstract import BaseTransform, BaseTransformSeeded
from rising.random import AbstractParameter
from rising.transforms.functional.crop import random_crop, center_crop

@@ -17,13 +19,12 @@ def __init__(self, size: Union[int, Sequence, AbstractParameter],
            grad: enable gradient computation inside transformation
            **kwargs: keyword arguments passed to augment_fn
        """

        super().__init__(augment_fn=center_crop, keys=keys,
                         grad=grad, property_names=('size', ), size=size,
                         **kwargs)


class RandomCrop(BaseTransform):
class RandomCrop(BaseTransformSeeded):
    def __init__(self, size: Union[int, Sequence, AbstractParameter],
                 dist: Union[int, Sequence, AbstractParameter] = 0,
                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
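
Because RandomCrop now derives from BaseTransformSeeded, the crop corner is sampled with the same seed for every key, which keeps paired tensors such as image and mask aligned. A short usage sketch, assuming the import path of this file:

import torch
from rising.transforms.crop import RandomCrop

data = torch.arange(100.).view(1, 1, 10, 10)
batch = {"data": data, "seg": data.clone()}

trafo = RandomCrop(size=3, keys=("data", "seg"))
cropped = trafo(**batch)
# the same seed is restored per key, so both crops use the same corner
assert torch.equal(cropped["data"], cropped["seg"])
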
7 changes: 5 additions & 2 deletions rising/transforms/functional/crop.py
@@ -1,6 +1,6 @@
import torch
import random
from typing import Union, Sequence
from typing import Union, Sequence, Tuple, List

from rising.utils import check_scalar

@@ -60,9 +60,12 @@ def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]],
    Returns:
        torch.Tensor: cropped output
        List[int]: top left corner used for crop
    """
    if check_scalar(dist):
        dist = [dist] * (data.ndim - 2)
    if isinstance(dist[0], torch.Tensor):
        dist = [int(i) for i in dist]
    if check_scalar(size):
        size = [size] * (data.ndim - 2)
    if not isinstance(size[0], int):
@@ -71,6 +74,6 @@ def random_crop(data: torch.Tensor, size: Union[int, Sequence[int]],
    if any([crop_dim + dist_dim >= img_dim for img_dim, crop_dim, dist_dim in zip(data.shape[2:], size, dist)]):
        raise TypeError(f"Crop can not be realized with given size {size} and dist {dist}.")

    corner = [random.randrange(0, img_dim - crop_dim - dist_dim) for
    corner = [torch.randint(0, img_dim - crop_dim - dist_dim, (1,)).item() for
              img_dim, crop_dim, dist_dim in zip(data.shape[2:], size, dist)]
    return crop(data, corner, size)
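
With the corner drawn via torch.randint instead of random.randrange, the crop is now reproducible through torch's seed alone. A small sketch, assuming this module path:

import torch
from rising.transforms.functional.crop import random_crop

data = torch.rand(2, 3, 16, 16)      # N x C x H x W

torch.manual_seed(42)
crop_a = random_crop(data, size=8)
torch.manual_seed(42)
crop_b = random_crop(data, size=8)
assert torch.equal(crop_a, crop_b)   # same torch seed -> same corner -> same crop
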
41 changes: 33 additions & 8 deletions rising/transforms/intensity.py
@@ -119,7 +119,12 @@ def __init__(self, mean: Union[float, Sequence[float]],


class Noise(PerChannelTransform):
    """Add noise to data"""
    """
    Add noise to data

    .. warning:: This transform will apply different noise patterns to
        different keys.
    """

    def __init__(self, noise_type: str, per_channel: bool = False,
                 keys: Sequence = ('data',), grad: bool = False, **kwargs):
@@ -131,16 +136,21 @@ def __init__(self, noise_type: str, per_channel: bool = False,
            keys: keys to normalize
            grad: enable gradient computation inside transformation
            kwargs: keyword arguments passed to noise function

        See Also
        --------
        :func:`torch.Tensor.normal_`, :func:`torch.Tensor.exponential_`
        See Also:
            :func:`torch.Tensor.normal_`, :func:`torch.Tensor.exponential_`
        """
        super().__init__(augment_fn=add_noise, per_channel=per_channel, keys=keys,
                         grad=grad, noise_type=noise_type, **kwargs)


class ExponentialNoise(Noise):
    """Add exponential noise to data"""
    """
    Add exponential noise to data

    .. warning:: This transform will apply different noise patterns to
        different keys.
    """

    def __init__(self, lambd: float, keys: Sequence = ('data',),
                 grad: bool = False, **kwargs):
@@ -156,7 +166,12 @@ def __init__(self, lambd: float, keys: Sequence = ('data',),


class GaussianNoise(Noise):
    """Add gaussian noise to data"""
    """
    Add Gaussian noise to data

    .. warning:: This transform will apply different noise patterns to
        different keys.
    """

    def __init__(self, mean: float, std: float, keys: Sequence = ('data',),
                 grad: bool = False, **kwargs):
@@ -193,6 +208,8 @@ class RandomValuePerChannel(PerChannelTransform):
    """
    Apply augmentations which take random values as input by keyword
    :attr:`value`

    .. warning:: This transform will apply different values to different keys.
    """

    def __init__(self, augment_fn: callable,
@@ -245,7 +262,11 @@ def forward(self, **data) -> dict:


class RandomAddValue(RandomValuePerChannel):
    """Increase values additively"""
    """
    Increase values additively

    .. warning:: This transform will apply different values to different keys.
    """

    def __init__(self, random_sampler: AbstractParameter,
                 per_channel: bool = False,
@@ -263,7 +284,11 @@ def __init__(self, random_sampler: AbstractParameter,


class RandomScaleValue(RandomValuePerChannel):
    """Scale Values"""
    """
    Scale Values

    .. warning:: This transform will apply different values to different keys.
    """

    def __init__(self, random_sampler: AbstractParameter,
                 per_channel: bool = False,
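
The warning on the noise transforms can be made concrete: every key draws its own noise, so two keys that start out identical no longer match afterwards. A hedged sketch using the GaussianNoise signature shown above:

import torch
from rising.transforms.intensity import GaussianNoise

batch = {"data": torch.zeros(1, 1, 8, 8), "target": torch.zeros(1, 1, 8, 8)}

trafo = GaussianNoise(mean=0., std=1., keys=("data", "target"))
out = trafo(**batch)
# independent noise per key, so the two outputs differ
print(torch.equal(out["data"], out["target"]))  # expected: False
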
7 changes: 4 additions & 3 deletions tests/transforms/functional/test_crop.py
@@ -18,10 +18,11 @@ def test_center_crop(self):
        self.assertTrue(all([_s == s for _s in crop.shape[2:]]))

    def test_random_crop(self):
        random.seed(0)
        h = random.randrange(0, 7)
        w = random.randrange(0, 7)
        torch.manual_seed(0)
        h = torch.randint(0, 7, (1,)).item()
        w = torch.randint(0, 7, (1,)).item()
        expected = self.data[:, :, h: h + 3, w: w + 3]
        torch.manual_seed(0)
        crop = random_crop(self.data, size=3.)
        self.assertTrue((crop == expected).all())
        self.assertTrue(all([_s == 3 for _s in crop.shape[2:]]))
31 changes: 18 additions & 13 deletions tests/transforms/test_crop.py
@@ -11,31 +11,36 @@ class TestCrop(unittest.TestCase):
    def setUp(self) -> None:
        data = torch.zeros(1, 1, 10, 10)
        data[:, :, 4:7, 4:7] = 1
        self.batch = {"data": data}
        self.batch = {"data": data, "seg": data.clone()}

    def test_center_crop_transform(self):
        for s in range(1, 10):
            trafo = CenterCrop(s)
            crop = trafo(**self.batch)["data"]
            trafo = CenterCrop(s, keys=("data", "seg"))
            crop = trafo(**self.batch)

            expected = center_crop(self.batch["data"], s)

            self.assertTrue((crop == expected).all())
            self.assertTrue(all([_s == s for _s in crop.shape[2:]]))
            self.assertTrue(expected.allclose(crop["data"]))
            self.assertTrue(expected.allclose(crop["seg"]))
            self.assertTrue(all([_s == s for _s in crop["data"].shape[2:]]))
            self.assertTrue(all([_s == s for _s in crop["seg"].shape[2:]]))

    def test_random_crop_transform(self):
        for s in range(9):
            random.seed(0)
            trafo = RandomCrop(s)
            crop = trafo(**self.batch)["data"]
        for s in range(1, 10):
            torch.manual_seed(s)
            trafo = RandomCrop(s, keys=("data", "seg"))
            crop = trafo(**self.batch)

            random.seed(0)
            _ = random.choices([0]) # internally sample size
            _ = random.choices([0]) # internally sample dist
            _ = random.choices([0]) # internally sample size in transform
            _ = random.choices([0]) # internally sample dist in transform
            torch.manual_seed(s) # seed random_crop
            expected = random_crop(self.batch["data"], size=s)

            self.assertTrue((crop == expected).all())
            self.assertTrue(all([_s == s for _s in crop.shape[2:]]))
            self.assertTrue(expected.allclose(crop["data"]))
            self.assertTrue(expected.allclose(crop["seg"]))
            self.assertTrue(all([_s == s for _s in crop["data"].shape[2:]]))
            self.assertTrue(all([_s == s for _s in crop["seg"].shape[2:]]))

    def test_center_crop_random_size_transform(self):
        for _ in range(10):
