Skip to content

Commit ce04eb5

Browse files
committed
simplify replace
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent c23f5a4 commit ce04eb5

File tree

3 files changed

+26
-73
lines changed

3 files changed

+26
-73
lines changed

monai/transforms/inverse.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,25 @@ def get_transform_info(self) -> dict:
8787
TraceKeys.DO_TRANSFORM: self._do_transform if hasattr(self, "_do_transform") else False,
8888
}
8989

90-
def push_transform(self, *args, **kwargs):
90+
def push_transform(self, data, *args, **kwargs):
9191
transform_info = self.get_transform_info()
92+
lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False)
93+
do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, False)
9294
if not kwargs:
9395
kwargs = {}
9496
kwargs["transform_info"] = transform_info
97+
replace = kwargs.pop("replace", False)
98+
if replace and isinstance(data, MetaTensor) and get_track_meta():
99+
if not lazy_eval:
100+
xform = self.pop_transform(data, check=False) if do_transform else {}
101+
return self.push_transform(data, extra_info=xform)
102+
elif do_transform:
103+
return self.push_transform(data, pending=data.pending_operations.pop()) # type: ignore
104+
else:
105+
return data
95106
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
96-
return TraceableTransform.track_pending_transform(*args, **kwargs)
97-
return TraceableTransform.track_transform(*args, **kwargs)
107+
return TraceableTransform.track_pending_transform(data, *args, **kwargs)
108+
return TraceableTransform.track_transform(data, *args, **kwargs)
98109

99110
@classmethod
100111
def track_transform(

monai/transforms/spatial/array.py

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,12 +1131,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
11311131
else:
11321132
out = convert_to_tensor(img, track_meta=get_track_meta())
11331133

1134-
if get_track_meta():
1135-
if not self.lazy_evaluation:
1136-
maybe_rot90_info = self.pop_transform(out, check=False) if self._do_transform else {}
1137-
self.push_transform(out, extra_info=maybe_rot90_info)
1138-
elif self._do_transform:
1139-
self.push_transform(out, pending=out.pending_operations.pop()) # type: ignore
1134+
self.push_transform(out, replace=True)
11401135
return out
11411136

11421137
def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -1261,13 +1256,7 @@ def __call__(
12611256
out = rotator(img)
12621257
else:
12631258
out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32)
1264-
if get_track_meta():
1265-
if not self.lazy_evaluation:
1266-
rot_info = self.pop_transform(out, check=False) if self._do_transform else {}
1267-
self.push_transform(out, extra_info=rot_info)
1268-
elif self._do_transform:
1269-
p = out.pending_operations.pop() # type: ignore
1270-
self.push_transform(out, pending=p)
1259+
self.push_transform(out, replace=True)
12711260
return out
12721261

12731262
def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -1309,13 +1298,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
13091298
self.randomize(None)
13101299
out = self.flipper(img) if self._do_transform else img
13111300
out = convert_to_tensor(out, track_meta=get_track_meta())
1312-
if get_track_meta():
1313-
if not self.lazy_evaluation:
1314-
xform_info = self.pop_transform(out, check=False) if self._do_transform else {}
1315-
self.push_transform(out, extra_info=xform_info)
1316-
elif self._do_transform:
1317-
p = out.pending_operations.pop() # type: ignore
1318-
self.push_transform(out, pending=p)
1301+
self.push_transform(out, replace=True)
13191302
return out
13201303

13211304
def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -1369,12 +1352,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
13691352
out = self.flipper(img)
13701353
else:
13711354
out = convert_to_tensor(img, track_meta=get_track_meta())
1372-
if get_track_meta():
1373-
if not self.lazy_evaluation:
1374-
xform = self.pop_transform(out, check=False) if self._do_transform else {}
1375-
self.push_transform(out, extra_info=xform)
1376-
elif self._do_transform:
1377-
self.push_transform(out, pending=out.pending_operations.pop()) # type: ignore
1355+
self.push_transform(out, replace=True)
13781356
return out
13791357

13801358
def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -1503,13 +1481,7 @@ def __call__(
15031481
)
15041482
xform.lazy_evaluation = self.lazy_evaluation
15051483
out = xform(img)
1506-
if get_track_meta():
1507-
if not self.lazy_evaluation:
1508-
z_info = self.pop_transform(out, check=False) if self._do_transform else {}
1509-
self.push_transform(out, extra_info=z_info)
1510-
elif self._do_transform:
1511-
p = out.pending_operations.pop()
1512-
self.push_transform(out, pending=p)
1484+
self.push_transform(out, replace=True)
15131485
return out # type: ignore
15141486

15151487
def inverse(self, data: torch.Tensor) -> torch.Tensor:

monai/transforms/spatial/dictionary.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -598,13 +598,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t
598598
rotator.lazy_evaluation = self.lazy_evaluation
599599
for key in self.key_iterator(d):
600600
d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta())
601-
if get_track_meta():
602-
if not self.lazy_evaluation:
603-
xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
604-
self.push_transform(d[key], extra_info=xform)
605-
elif self._do_transform:
606-
self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore
607-
601+
self.push_transform(d[key], replace=True)
608602
return d
609603

610604
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -942,12 +936,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
942936
d[key] = self.rand_affine(d[key], mode=mode, padding_mode=padding_mode, grid=grid) # type: ignore
943937
else:
944938
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
945-
if get_track_meta():
946-
if not self.lazy_evaluation:
947-
xform = self.pop_transform(d[key], check=False) if do_resampling else {}
948-
self.push_transform(d[key], extra_info=xform)
949-
elif do_resampling and isinstance(d[key], MetaTensor):
950-
self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore
939+
self._do_transform = do_resampling # TODO: unify self._do_transform and do_resampling
940+
self.push_transform(d[key], replace=True)
951941
return d
952942

953943
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
@@ -1320,12 +1310,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
13201310
d[key] = self.flipper(d[key])
13211311
else:
13221312
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
1323-
if get_track_meta():
1324-
if not self.lazy_evaluation:
1325-
xform_info = self.pop_transform(d[key], check=False) if self._do_transform else {}
1326-
self.push_transform(d[key], extra_info=xform_info)
1327-
elif self._do_transform:
1328-
self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore
1313+
self.push_transform(d[key], replace=True)
13291314
return d
13301315

13311316
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -1386,12 +1371,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
13861371
d[key] = self.flipper(d[key], randomize=False)
13871372
else:
13881373
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
1389-
if get_track_meta():
1390-
if not self.lazy_evaluation:
1391-
xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
1392-
self.push_transform(d[key], extra_info=xform)
1393-
elif self._do_transform:
1394-
self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore
1374+
self.push_transform(d[key], replace=True)
13951375
return d
13961376

13971377
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -1564,12 +1544,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
15641544
)
15651545
else:
15661546
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
1567-
if get_track_meta():
1568-
if not self.lazy_evaluation:
1569-
rot_info = self.pop_transform(d[key], check=False) if self._do_transform else {}
1570-
self.push_transform(d[key], extra_info=rot_info)
1571-
elif self._do_transform:
1572-
self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore
1547+
self.push_transform(d[key], replace=True)
15731548
return d
15741549

15751550
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -1744,12 +1719,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
17441719
)
17451720
else:
17461721
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
1747-
if get_track_meta():
1748-
if not self.lazy_evaluation:
1749-
xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
1750-
self.push_transform(d[key], extra_info=xform)
1751-
elif self._do_transform:
1752-
self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore
1722+
self.push_transform(d[key], replace=True)
17531723
return d
17541724

17551725
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:

0 commit comments

Comments
 (0)