Skip to content

Remove apply_same_field #3556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 30, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions monai/transforms/smooth_field/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ class RandSmoothFieldAdjustContrastd(RandomizableTransform, MapTransform):
align_corners: if True align the corners when upsampling field
prob: probability transform is applied
gamma: (min, max) range for exponential field
apply_same_field: if True, apply the same field to each key, otherwise randomize individually
device: Pytorch device to define field on
"""

Expand All @@ -66,13 +65,11 @@ def __init__(
align_corners: Optional[bool] = None,
prob: float = 0.1,
gamma: Union[Sequence[float], float] = (0.5, 4.5),
apply_same_field: bool = True,
device: Optional[torch.device] = None,
):
RandomizableTransform.__init__(self, prob)
MapTransform.__init__(self, keys)

self.apply_same_field = apply_same_field
self.mode = ensure_tuple_rep(mode, len(self.keys))

self.trans = RandSmoothFieldAdjustContrast(
Expand Down Expand Up @@ -108,9 +105,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
d = dict(data)

for idx, key in enumerate(self.key_iterator(d)):
if not self.apply_same_field:
self.randomize() # new field for every key

self.trans.set_mode(self.mode[idx % len(self.mode)])
d[key] = self.trans(d[key], False)

Expand All @@ -134,7 +128,6 @@ class RandSmoothFieldAdjustIntensityd(RandomizableTransform, MapTransform):
align_corners: if True align the corners when upsampling field
prob: probability transform is applied
gamma: (min, max) range of intensity multipliers
apply_same_field: if True, apply the same field to each key, otherwise randomize individually
device: Pytorch device to define field on
"""

Expand All @@ -150,13 +143,11 @@ def __init__(
align_corners: Optional[bool] = None,
prob: float = 0.1,
gamma: Union[Sequence[float], float] = (0.1, 1.0),
apply_same_field: bool = True,
device: Optional[torch.device] = None,
):
RandomizableTransform.__init__(self, prob)
MapTransform.__init__(self, keys)

self.apply_same_field = apply_same_field
self.mode = ensure_tuple_rep(mode, len(self.keys))

self.trans = RandSmoothFieldAdjustIntensity(
Expand Down Expand Up @@ -190,9 +181,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
d = dict(data)

for idx, key in enumerate(self.key_iterator(d)):
if not self.apply_same_field:
self.randomize() # new field for every key

self.trans.set_mode(self.mode[idx % len(self.mode)])
d[key] = self.trans(d[key], False)

Expand Down Expand Up @@ -220,7 +208,6 @@ class RandSmoothDeformd(RandomizableTransform, MapTransform):
grid_mode: interpolation mode used for sampling input using deformation grid
grid_padding_mode: padding mode used for sampling input using deformation grid
grid_align_corners: if True align the corners when sampling the deformation grid
apply_same_field: if True, apply the same field to each key, otherwise randomize individually
device: Pytorch device to define field on
"""

Expand All @@ -240,15 +227,13 @@ def __init__(
grid_mode: Union[GridSampleModeType, Sequence[GridSampleModeType]] = GridSampleMode.NEAREST,
grid_padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
grid_align_corners: Optional[bool] = False,
apply_same_field: bool = True,
device: Optional[torch.device] = None,
):
RandomizableTransform.__init__(self, prob)
MapTransform.__init__(self, keys)

self.field_mode = ensure_tuple_rep(field_mode, len(self.keys))
self.grid_mode = ensure_tuple_rep(grid_mode, len(self.keys))
self.apply_same_field = apply_same_field

self.trans = RandSmoothDeform(
rand_size=rand_size,
Expand Down Expand Up @@ -285,9 +270,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
d = dict(data)

for idx, key in enumerate(self.key_iterator(d)):
if not self.apply_same_field:
self.randomize() # new field for every key

self.trans.set_field_mode(self.field_mode[idx % len(self.field_mode)])
self.trans.set_grid_mode(self.grid_mode[idx % len(self.grid_mode)])

Expand Down