
Commit 62f8172

More work towards lazy resampling

1 parent b76e965 commit 62f8172

File tree

9 files changed: +543 -198 lines changed

monai/data/meta_tensor.py

Lines changed: 2 additions & 1 deletion

@@ -158,8 +158,9 @@ def __init__(
     def push_pending_transform(self, meta_matrix):
         self._pending_transforms.append(meta_matrix)
 
+    @property
     def has_pending_transforms(self):
-        return len(self._pending_transforms)
+        return len(self._pending_transforms) > 0
 
     def peek_pending_transform(self):
         return copy.deepcopy(self._pending_transforms[-1])
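
The change above turns has_pending_transforms into a property that answers a yes/no question instead of returning the raw list length. A minimal standalone sketch of the pending-transform bookkeeping pattern (illustrative only, not the real MetaTensor class):

import copy

class PendingTransformRecord:
    """Toy stand-in for a tensor that queues transforms instead of applying them."""

    def __init__(self):
        self._pending_transforms = []

    def push_pending_transform(self, meta_matrix):
        # record the transform description; no resampling happens here
        self._pending_transforms.append(meta_matrix)

    @property
    def has_pending_transforms(self):
        # a property returning a bool, as in the diff above
        return len(self._pending_transforms) > 0

    def peek_pending_transform(self):
        # deep copy so callers cannot mutate the queued entry
        return copy.deepcopy(self._pending_transforms[-1])

t = PendingTransformRecord()
t.push_pending_transform({"name": "rotate"})
assert t.has_pending_transforms  # accessed as a property, no call parentheses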

monai/transforms/atmostonce/apply.py

Lines changed: 2 additions & 69 deletions

@@ -141,10 +141,8 @@ def apply(data: Union[torch.Tensor, MetaTensor],
 
     for meta_matrix in pending_:
         next_matrix = meta_matrix.matrix
-        print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix))
-        # cumulative_matrix = matmul(next_matrix, cumulative_matrix)
+        # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix))
        cumulative_matrix = matmul(cumulative_matrix, next_matrix)
-        # cumulative_extents = [e @ translate_to_centre.matrix.matrix for e in cumulative_extents]
        cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents]
 
        new_mode = meta_matrix.metadata.get('mode', None)
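
The loop in this hunk folds each pending affine into one cumulative matrix so the data is resampled once at the end rather than once per transform; whether the new factor multiplies on the left or the right depends on the point convention, which is what the swap between the commented-out and kept matmul lines reflects. A toy NumPy check (illustrative, not MONAI code) that fusing two affines is just a matrix product:

import numpy as np

# 2D rotation by 90 degrees and a translation by (5, 0), in homogeneous form
rot = np.array([[0., -1., 0.],
                [1.,  0., 0.],
                [0.,  0., 1.]])
shift = np.array([[1., 0., 5.],
                  [0., 1., 0.],
                  [0., 0., 1.]])

p = np.array([2., 3., 1.])        # column-vector convention: points transform as M @ p

step_by_step = shift @ (rot @ p)  # rotate first, then translate
fused = (shift @ rot) @ p         # same result with a single fused matrix
assert np.allclose(step_by_step, fused)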
@@ -160,12 +158,10 @@ def apply(data: Union[torch.Tensor, MetaTensor],
 
        if (mode_compat is False or padding_mode_compat is False or
                device_compat is False or dtype_compat is False):
-            print("intermediate apply required")
            # carry out an intermediate resample here due to incompatibility between arguments
            kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
 
            cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
-            print(f"intermediate applying with cumulative matrix\n {cumulative_matrix_}")
            a = Affine(norm_coords=False,
                       affine=cumulative_matrix_,
                       **kwargs)
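
This branch flushes the accumulated matrix with an intermediate resample whenever the next pending transform asks for interpolation, padding, device, or dtype settings that cannot be fused with the ones gathered so far. A rough sketch of that control flow (the resample_once callable and the dict layout are placeholders, not the MONAI API):

def apply_pending(data, pending, resample_once):
    # Fuse pending affines while their resampling arguments agree; force an
    # intermediate resample at the point where the arguments change.
    cumulative = None
    cur_args = None
    for t in pending:
        args = (t["mode"], t["padding_mode"], t["device"], t["dtype"])
        if cur_args is not None and args != cur_args:
            # incompatible arguments: apply what has been accumulated so far
            data = resample_once(data, cumulative, *cur_args)
            cumulative = None
        cumulative = t["matrix"] if cumulative is None else cumulative @ t["matrix"]
        cur_args = args
    if cumulative is not None:
        # final resample with the remaining fused matrix
        data = resample_once(data, cumulative, *cur_args)
    return data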
@@ -184,7 +180,7 @@ def apply(data: Union[torch.Tensor, MetaTensor],
 
    cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
 
-    print(f"applying with cumulative matrix\n {cumulative_matrix_}")
+    # print(f"applying with cumulative matrix\n {cumulative_matrix_}")
    a = Affine(norm_coords=False,
               affine=cumulative_matrix_,
               spatial_size=cur_shape[1:],
@@ -224,66 +220,3 @@ def __call__(
 
    def inverse(self, data):
        return NotImplementedError()
-
-
-# class Applyd(MapTransform, InvertibleTransform):
-#
-#     def __init__(self,
-#                  keys: KeysCollection,
-#                  modes: GridSampleModeSequence,
-#                  padding_modes: GridSamplePadModeSequence,
-#                  normalized: bool = False,
-#                  device: Optional[torch.device] = None,
-#                  dtypes: Optional[DtypeSequence] = np.float32):
-#         self.keys = keys
-#         self.modes = modes
-#         self.padding_modes = padding_modes
-#         self.device = device
-#         self.dtypes = dtypes
-#         self.resamplers = dict()
-#
-#         if isinstance(dtypes, (list, tuple)):
-#             if len(keys) != len(dtypes):
-#                 raise ValueError("'keys' and 'dtypes' must be the same length if 'dtypes' is a sequence")
-#
-#             # create a resampler for each output data type
-#             unique_resamplers = dict()
-#             for d in dtypes:
-#                 if d not in unique_resamplers:
-#                     unique_resamplers[d] = Resample(norm_coords=not normalized, device=device, dtype=d)
-#
-#             # assign each named data input the appropriate resampler for that data type
-#             for k, d in zip(keys, dtypes):
-#                 if k not in self.resamplers:
-#                     self.resamplers[k] = unique_resamplers[d]
-#
-#         else:
-#             # share the resampler across all named data inputs
-#             resampler = Resample(norm_coords=not normalized, device=device, dtype=dtypes)
-#             for k in keys:
-#                 self.resamplers[k] = resampler
-#
-#     def __call__(self,
-#                  data: Mapping[Hashable, NdarrayOrTensor],
-#                  allow_missing_keys: bool = False) -> Dict[Hashable, NdarrayOrTensor]:
-#         d = dict(data)
-#         mapping_stack = d["mappings"]
-#         keys = d.keys()
-#         for key_tuple in self.key_iterator(d,
-#                                            expand_scalar_to_tuple(self.modes, len(keys)),
-#                                            expand_scalar_to_tuple(self.padding_modes, len(keys)),
-#                                            expand_scalar_to_tuple(self.dtypes, len(keys))):
-#             key, mode, padding_mode, dtype = key_tuple
-#             affine = mapping_stack[key].transform()
-#             data = d[key]
-#             spatial_size = data.shape[1:]
-#             grid = create_grid(spatial_size, device=self.device, backend="torch", dtype=dtype)
-#             _device = grid.device
-#
-#             _b = TransformBackends.TORCH if isinstance(grid, torch.Tensor) else TransformBackends.NUMPY
-#
-#             grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=grid.dtype)
-#             affine, *_ = convert_to_dst_type(affine, grid)
-#             d[key] = self.resamplers[key](data, grid=grid, mode=mode, padding_mode=padding_mode)
-#
-#         return d
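
The commented-out Applyd draft removed in this hunk built one Resample instance per distinct output dtype and shared it across the dictionary keys mapped to that dtype, so several keys with the same dtype do not each get their own resampler. That sharing logic in isolation (names are illustrative; make_resampler stands in for constructing a Resample):

def build_resamplers(keys, dtypes, make_resampler):
    if isinstance(dtypes, (list, tuple)):
        if len(keys) != len(dtypes):
            raise ValueError("'keys' and 'dtypes' must be the same length if 'dtypes' is a sequence")
        # one resampler per distinct dtype, then shared among the matching keys
        unique = {}
        for d in dtypes:
            if d not in unique:
                unique[d] = make_resampler(d)
        return {k: unique[d] for k, d in zip(keys, dtypes)}
    # a single dtype: every key shares the same resampler
    shared = make_resampler(dtypes)
    return {k: shared for k in keys}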
