Skip to content

Commit b76e965

Browse files
committed
Compose compile; initial multisample generic croppad (array and dict) implementations
1 parent 6295e02 commit b76e965

File tree

5 files changed

+456
-93
lines changed

5 files changed

+456
-93
lines changed

monai/transforms/atmostonce/apply.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ def matrix_from_matrix_container(matrix):
110110

111111
def apply(data: Union[torch.Tensor, MetaTensor],
112112
pending: Optional[dict] = None):
113+
114+
if isinstance(data, dict):
115+
rd = dict()
116+
for k, v in data.items():
117+
result = apply(v)
118+
rd[k] = result
119+
return rd
120+
113121
pending_ = pending
114122
pending_ = data.pending_transforms
115123

@@ -188,17 +196,36 @@ def apply(data: Union[torch.Tensor, MetaTensor],
188196
return data
189197

190198

199+
# make Apply universal for arrays and dictionaries; it just calls through to functional apply
191200
class Apply(InvertibleTransform):
192201

193202
def __init__(self):
194203
super().__init__()
195204

205+
def __call__(self, *args, **kwargs):
206+
return apply(*args, **kwargs)
207+
208+
def inverse(self, data):
209+
return NotImplementedError()
210+
196211

197212
class Applyd(MapTransform, InvertibleTransform):
198213

199214
def __init__(self):
200215
super().__init__()
201216

217+
def __call__(
218+
self,
219+
d: dict
220+
):
221+
rd = dict()
222+
for k, v in d.items():
223+
rd[k] = apply(v)
224+
225+
def inverse(self, data):
226+
return NotImplementedError()
227+
228+
202229
# class Applyd(MapTransform, InvertibleTransform):
203230
#
204231
# def __init__(self,

monai/transforms/atmostonce/array.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -737,15 +737,20 @@ def __call__(
737737
img: torch.Tensor,
738738
randomize: Optional[bool] = True
739739
):
740+
741+
img_shape = img.shape[:1]
742+
740743
if randomize:
741744
self.randomize(img)
742745

743746
if self._do_transform:
744747
offsets_ = self.offsets
745-
slices = tuple(slice(o, o + s) for o, s in zip(offsets_, self.sizes))
746-
return self.op(img, slices=slices)
747748
else:
748-
return self.op(img)
749+
# center crop if this sample isn't random
750+
offsets_ = tuple((i - s) // 2 for i, s in zip(img_shape, self.sizes))
751+
slices = tuple(slice(o, o + s) for o, s in zip(offsets_, self.sizes))
752+
return self.op(img, slices=slices)
753+
749754

750755
def inverse(
751756
self,

monai/transforms/atmostonce/compose.py

Lines changed: 113 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3,85 +3,86 @@
33

44
import numpy as np
55

6-
6+
from monai.transforms.atmostonce.apply import Apply
77
from monai.transforms.atmostonce.lazy_transform import LazyTransform, compile_lazy_transforms, flatten_sequences
8-
from monai.transforms.atmostonce.utility import CachedTransformCompose
8+
from monai.transforms.atmostonce.utility import CachedTransformCompose, MultiSampleTransformCompose, \
9+
IMultiSampleTransform, IRandomizableTransform, ILazyTransform
910
from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, get_seed, MAX_SEED
1011

1112
from monai.transforms import Randomizable, InvertibleTransform, OneOf, apply_transform
1213

1314

1415
# TODO: this is intended to replace Compose once development is done
1516

16-
class ComposeCompiler:
17-
"""
18-
Args:
19-
transforms: A sequence of callable transforms
20-
lazy_resampling: Whether to resample the data after each transform or accumulate
21-
changes and then resample according to the accumulated changes as few times as
22-
possible. Defaults to True as this nearly always improves speed and quality
23-
caching_policy: Whether to cache deterministic transforms before the first
24-
randomised transforms. This can be one of "off", "drive", "memory"
25-
caching_favor: Whether to cache primarily for "speed" or for "quality". "speed" will
26-
favor doing more work before caching, whereas "quality" will favour delaying
27-
resampling until after caching
28-
"""
29-
def __init__(
30-
self,
31-
transforms: Union[Sequence[Callable], Callable],
32-
lazy_resampling: Optional[bool] = True,
33-
caching_policy: Optional[str] = "off",
34-
caching_favor: Optional[str] = "quality"
35-
):
36-
valid_caching_policies = ("off", "drive", "memory")
37-
if caching_policy not in valid_caching_policies:
38-
raise ValueError("parameter 'caching_policy' must be one of "
39-
f"{valid_caching_policies} but is '{caching_policy}'")
40-
41-
dest_transforms = None
42-
43-
if caching_policy == "off":
44-
if lazy_resampling is False:
45-
dest_transforms = [t for t in transforms]
46-
else:
47-
dest_transforms = ComposeCompiler.lazy_no_cache()
48-
else:
49-
if caching_policy == "drive":
50-
raise NotImplementedError()
51-
elif caching_policy == "memory":
52-
raise NotImplementedError()
53-
54-
self.src_transforms = [t for t in transforms]
55-
self.dest_transforms = dest_transforms
56-
57-
def __getitem__(
58-
self,
59-
index
60-
):
61-
return self.dest_transforms[index]
62-
63-
def __len__(self):
64-
return len(self.dest_transforms)
65-
66-
@staticmethod
67-
def lazy_no_cache(transforms):
68-
dest_transforms = []
69-
# TODO: replace with lazy transform
70-
cur_lazy = []
71-
for i_t in range(1, len(transforms)):
72-
if isinstance(transforms[i_t], LazyTransform):
73-
# add this to the stack of transforms to be handled lazily
74-
cur_lazy.append(transforms[i_t])
75-
else:
76-
if len(cur_lazy) > 0:
77-
dest_transforms.append(cur_lazy)
78-
# TODO: replace with lazy transform
79-
cur_lazy = []
80-
dest_transforms.append(transforms[i_t])
81-
return dest_transforms
17+
# class ComposeCompiler:
18+
# """
19+
# Args:
20+
# transforms: A sequence of callable transforms
21+
# lazy_resampling: Whether to resample the data after each transform or accumulate
22+
# changes and then resample according to the accumulated changes as few times as
23+
# possible. Defaults to True as this nearly always improves speed and quality
24+
# caching_policy: Whether to cache deterministic transforms before the first
25+
# randomised transforms. This can be one of "off", "drive", "memory"
26+
# caching_favor: Whether to cache primarily for "speed" or for "quality". "speed" will
27+
# favor doing more work before caching, whereas "quality" will favour delaying
28+
# resampling until after caching
29+
# """
30+
# def __init__(
31+
# self,
32+
# transforms: Union[Sequence[Callable], Callable],
33+
# lazy_resampling: Optional[bool] = True,
34+
# caching_policy: Optional[str] = "off",
35+
# caching_favor: Optional[str] = "quality"
36+
# ):
37+
# valid_caching_policies = ("off", "drive", "memory")
38+
# if caching_policy not in valid_caching_policies:
39+
# raise ValueError("parameter 'caching_policy' must be one of "
40+
# f"{valid_caching_policies} but is '{caching_policy}'")
41+
#
42+
# dest_transforms = None
43+
#
44+
# if caching_policy == "off":
45+
# if lazy_resampling is False:
46+
# dest_transforms = [t for t in transforms]
47+
# else:
48+
# dest_transforms = ComposeCompiler.lazy_no_cache()
49+
# else:
50+
# if caching_policy == "drive":
51+
# raise NotImplementedError()
52+
# elif caching_policy == "memory":
53+
# raise NotImplementedError()
54+
#
55+
# self.src_transforms = [t for t in transforms]
56+
# self.dest_transforms = dest_transforms
57+
#
58+
# def __getitem__(
59+
# self,
60+
# index
61+
# ):
62+
# return self.dest_transforms[index]
63+
#
64+
# def __len__(self):
65+
# return len(self.dest_transforms)
66+
#
67+
# @staticmethod
68+
# def lazy_no_cache(transforms):
69+
# dest_transforms = []
70+
# # TODO: replace with lazy transform
71+
# cur_lazy = []
72+
# for i_t in range(1, len(transforms)):
73+
# if isinstance(transforms[i_t], LazyTransform):
74+
# # add this to the stack of transforms to be handled lazily
75+
# cur_lazy.append(transforms[i_t])
76+
# else:
77+
# if len(cur_lazy) > 0:
78+
# dest_transforms.append(cur_lazy)
79+
# # TODO: replace with lazy transform
80+
# cur_lazy = []
81+
# dest_transforms.append(transforms[i_t])
82+
# return dest_transforms
8283

8384

84-
class ComposeCompiler2:
85+
class ComposeCompiler:
8586

8687
def compile(self, transforms, cache_mechanism):
8788

@@ -93,26 +94,59 @@ def compile(self, transforms, cache_mechanism):
9394

9495
return transforms___
9596

96-
def compile_caching(self, transforms, cache_stategy):
97+
def compile_caching(self, transforms, cache_mechanism):
98+
# TODO: handle being passed a transform list with containers
9799
# given a list of transforms, determine where to add a cached transform object
98100
# and what transforms to put in it
99-
return transforms
101+
cacheable = list()
102+
for t in transforms:
103+
if self.transform_is_random(t) is False:
104+
cacheable.append(t)
105+
else:
106+
break
107+
108+
if len(cacheable) == 0:
109+
return list(transforms)
110+
else:
111+
return [CachedTransformCompose(cacheable, cache_mechanism)] + transforms[len(cacheable):]
100112

101113
def compile_multisampling(self, transforms):
102-
return transforms
114+
for i in reversed(range(len(transforms))):
115+
if self.transform_is_multisampling(transforms[i]) is True:
116+
transforms_ = transforms[:i] + [MultiSampleTransformCompose(transforms[i],
117+
transforms[i+1:])]
118+
return self.compile_multisampling(transforms_)
119+
120+
return list(transforms)
103121

104122
def compile_lazy_resampling(self, transforms):
105-
return transforms
123+
result = list()
124+
lazy = list()
125+
for i in range(len(transforms)):
126+
if self.transform_is_lazy(transforms[i]):
127+
lazy.append(transforms[i])
128+
else:
129+
if len(lazy) > 0:
130+
result.extend(lazy)
131+
result.append(Apply())
132+
lazy = list()
133+
result.append(transforms[i])
134+
if len(lazy) > 0:
135+
result.extend(lazy)
136+
result.append(Apply())
137+
return result
138+
139+
def transform_is_random(self, t):
140+
return isinstance(t, IRandomizableTransform)
106141

107142
def transform_is_container(self, t):
108-
if isinstance(t, CachedTransform):
109-
return True
110-
return False
143+
return isinstance(t, CachedTransformCompose, MultiSampleTransformCompose)
111144

112145
def transform_is_multisampling(self, t):
113-
# if isinstance(t, MultiSamplingTransform):
114-
# return True
115-
return False
146+
return isinstance(t, IMultiSampleTransform)
147+
148+
def transform_is_lazy(self, t):
149+
return isinstance(t, ILazyTransform)
116150

117151

118152
class Compose(Randomizable, InvertibleTransform):

0 commit comments

Comments
 (0)