33
44import numpy as np
55
6-
6+ from monai . transforms . atmostonce . apply import Apply
77from monai .transforms .atmostonce .lazy_transform import LazyTransform , compile_lazy_transforms , flatten_sequences
8- from monai .transforms .atmostonce .utility import CachedTransformCompose
8+ from monai .transforms .atmostonce .utility import CachedTransformCompose , MultiSampleTransformCompose , \
9+ IMultiSampleTransform , IRandomizableTransform , ILazyTransform
910from monai .utils import GridSampleMode , GridSamplePadMode , ensure_tuple , get_seed , MAX_SEED
1011
1112from monai .transforms import Randomizable , InvertibleTransform , OneOf , apply_transform
1213
1314
1415# TODO: this is intended to replace Compose once development is done
1516
16- class ComposeCompiler :
17- """
18- Args:
19- transforms: A sequence of callable transforms
20- lazy_resampling: Whether to resample the data after each transform or accumulate
21- changes and then resample according to the accumulated changes as few times as
22- possible. Defaults to True as this nearly always improves speed and quality
23- caching_policy: Whether to cache deterministic transforms before the first
24- randomised transforms. This can be one of "off", "drive", "memory"
25- caching_favor: Whether to cache primarily for "speed" or for "quality". "speed" will
26- favor doing more work before caching, whereas "quality" will favour delaying
27- resampling until after caching
28- """
29- def __init__ (
30- self ,
31- transforms : Union [Sequence [Callable ], Callable ],
32- lazy_resampling : Optional [bool ] = True ,
33- caching_policy : Optional [str ] = "off" ,
34- caching_favor : Optional [str ] = "quality"
35- ):
36- valid_caching_policies = ("off" , "drive" , "memory" )
37- if caching_policy not in valid_caching_policies :
38- raise ValueError ("parameter 'caching_policy' must be one of "
39- f"{ valid_caching_policies } but is '{ caching_policy } '" )
40-
41- dest_transforms = None
42-
43- if caching_policy == "off" :
44- if lazy_resampling is False :
45- dest_transforms = [t for t in transforms ]
46- else :
47- dest_transforms = ComposeCompiler .lazy_no_cache ()
48- else :
49- if caching_policy == "drive" :
50- raise NotImplementedError ()
51- elif caching_policy == "memory" :
52- raise NotImplementedError ()
53-
54- self .src_transforms = [t for t in transforms ]
55- self .dest_transforms = dest_transforms
56-
57- def __getitem__ (
58- self ,
59- index
60- ):
61- return self .dest_transforms [index ]
62-
63- def __len__ (self ):
64- return len (self .dest_transforms )
65-
66- @staticmethod
67- def lazy_no_cache (transforms ):
68- dest_transforms = []
69- # TODO: replace with lazy transform
70- cur_lazy = []
71- for i_t in range (1 , len (transforms )):
72- if isinstance (transforms [i_t ], LazyTransform ):
73- # add this to the stack of transforms to be handled lazily
74- cur_lazy .append (transforms [i_t ])
75- else :
76- if len (cur_lazy ) > 0 :
77- dest_transforms .append (cur_lazy )
78- # TODO: replace with lazy transform
79- cur_lazy = []
80- dest_transforms .append (transforms [i_t ])
81- return dest_transforms
17+ # class ComposeCompiler:
18+ # """
19+ # Args:
20+ # transforms: A sequence of callable transforms
21+ # lazy_resampling: Whether to resample the data after each transform or accumulate
22+ # changes and then resample according to the accumulated changes as few times as
23+ # possible. Defaults to True as this nearly always improves speed and quality
24+ # caching_policy: Whether to cache deterministic transforms before the first
25+ # randomised transforms. This can be one of "off", "drive", "memory"
26+ # caching_favor: Whether to cache primarily for "speed" or for "quality". "speed" will
27+ # favor doing more work before caching, whereas "quality" will favour delaying
28+ # resampling until after caching
29+ # """
30+ # def __init__(
31+ # self,
32+ # transforms: Union[Sequence[Callable], Callable],
33+ # lazy_resampling: Optional[bool] = True,
34+ # caching_policy: Optional[str] = "off",
35+ # caching_favor: Optional[str] = "quality"
36+ # ):
37+ # valid_caching_policies = ("off", "drive", "memory")
38+ # if caching_policy not in valid_caching_policies:
39+ # raise ValueError("parameter 'caching_policy' must be one of "
40+ # f"{valid_caching_policies} but is '{caching_policy}'")
41+ #
42+ # dest_transforms = None
43+ #
44+ # if caching_policy == "off":
45+ # if lazy_resampling is False:
46+ # dest_transforms = [t for t in transforms]
47+ # else:
48+ # dest_transforms = ComposeCompiler.lazy_no_cache()
49+ # else:
50+ # if caching_policy == "drive":
51+ # raise NotImplementedError()
52+ # elif caching_policy == "memory":
53+ # raise NotImplementedError()
54+ #
55+ # self.src_transforms = [t for t in transforms]
56+ # self.dest_transforms = dest_transforms
57+ #
58+ # def __getitem__(
59+ # self,
60+ # index
61+ # ):
62+ # return self.dest_transforms[index]
63+ #
64+ # def __len__(self):
65+ # return len(self.dest_transforms)
66+ #
67+ # @staticmethod
68+ # def lazy_no_cache(transforms):
69+ # dest_transforms = []
70+ # # TODO: replace with lazy transform
71+ # cur_lazy = []
72+ # for i_t in range(1, len(transforms)):
73+ # if isinstance(transforms[i_t], LazyTransform):
74+ # # add this to the stack of transforms to be handled lazily
75+ # cur_lazy.append(transforms[i_t])
76+ # else:
77+ # if len(cur_lazy) > 0:
78+ # dest_transforms.append(cur_lazy)
79+ # # TODO: replace with lazy transform
80+ # cur_lazy = []
81+ # dest_transforms.append(transforms[i_t])
82+ # return dest_transforms
8283
8384
84- class ComposeCompiler2 :
85+ class ComposeCompiler :
8586
8687 def compile (self , transforms , cache_mechanism ):
8788
@@ -93,26 +94,59 @@ def compile(self, transforms, cache_mechanism):
9394
9495 return transforms___
9596
96- def compile_caching (self , transforms , cache_stategy ):
97+ def compile_caching (self , transforms , cache_mechanism ):
98+ # TODO: handle being passed a transform list with containers
9799 # given a list of transforms, determine where to add a cached transform object
98100 # and what transforms to put in it
99- return transforms
101+ cacheable = list ()
102+ for t in transforms :
103+ if self .transform_is_random (t ) is False :
104+ cacheable .append (t )
105+ else :
106+ break
107+
108+ if len (cacheable ) == 0 :
109+ return list (transforms )
110+ else :
111+ return [CachedTransformCompose (cacheable , cache_mechanism )] + transforms [len (cacheable ):]
100112
101113 def compile_multisampling (self , transforms ):
102- return transforms
114+ for i in reversed (range (len (transforms ))):
115+ if self .transform_is_multisampling (transforms [i ]) is True :
116+ transforms_ = transforms [:i ] + [MultiSampleTransformCompose (transforms [i ],
117+ transforms [i + 1 :])]
118+ return self .compile_multisampling (transforms_ )
119+
120+ return list (transforms )
103121
104122 def compile_lazy_resampling (self , transforms ):
105- return transforms
123+ result = list ()
124+ lazy = list ()
125+ for i in range (len (transforms )):
126+ if self .transform_is_lazy (transforms [i ]):
127+ lazy .append (transforms [i ])
128+ else :
129+ if len (lazy ) > 0 :
130+ result .extend (lazy )
131+ result .append (Apply ())
132+ lazy = list ()
133+ result .append (transforms [i ])
134+ if len (lazy ) > 0 :
135+ result .extend (lazy )
136+ result .append (Apply ())
137+ return result
138+
139+ def transform_is_random (self , t ):
140+ return isinstance (t , IRandomizableTransform )
106141
107142 def transform_is_container (self , t ):
108- if isinstance (t , CachedTransform ):
109- return True
110- return False
143+ return isinstance (t , CachedTransformCompose , MultiSampleTransformCompose )
111144
112145 def transform_is_multisampling (self , t ):
113- # if isinstance(t, MultiSamplingTransform):
114- # return True
115- return False
146+ return isinstance (t , IMultiSampleTransform )
147+
148+ def transform_is_lazy (self , t ):
149+ return isinstance (t , ILazyTransform )
116150
117151
118152class Compose (Randomizable , InvertibleTransform ):
0 commit comments