@@ -598,13 +598,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t
598598 rotator .lazy_evaluation = self .lazy_evaluation
599599 for key in self .key_iterator (d ):
600600 d [key ] = rotator (d [key ]) if self ._do_transform else convert_to_tensor (d [key ], track_meta = get_track_meta ())
601- if get_track_meta ():
602- if not self .lazy_evaluation :
603- xform = self .pop_transform (d [key ], check = False ) if self ._do_transform else {}
604- self .push_transform (d [key ], extra_info = xform )
605- elif self ._do_transform :
606- self .push_transform (d [key ], pending = d [key ].pending_operations .pop ()) # type: ignore
607-
601+ self .push_transform (d [key ], replace = True )
608602 return d
609603
610604 def inverse (self , data : Mapping [Hashable , torch .Tensor ]) -> dict [Hashable , torch .Tensor ]:
@@ -942,12 +936,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
942936 d [key ] = self .rand_affine (d [key ], mode = mode , padding_mode = padding_mode , grid = grid ) # type: ignore
943937 else :
944938 d [key ] = convert_to_tensor (d [key ], track_meta = get_track_meta (), dtype = torch .float32 )
945- if get_track_meta ():
946- if not self .lazy_evaluation :
947- xform = self .pop_transform (d [key ], check = False ) if do_resampling else {}
948- self .push_transform (d [key ], extra_info = xform )
949- elif do_resampling and isinstance (d [key ], MetaTensor ):
950- self .push_transform (d [key ], pending = d [key ].pending_operations .pop ()) # type: ignore
939+ self ._do_transform = do_resampling # TODO: unify self._do_transform and do_resampling
940+ self .push_transform (d [key ], replace = True )
951941 return d
952942
953943 def inverse (self , data : Mapping [Hashable , NdarrayOrTensor ]) -> dict [Hashable , NdarrayOrTensor ]:
@@ -1320,12 +1310,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
13201310 d [key ] = self .flipper (d [key ])
13211311 else :
13221312 d [key ] = convert_to_tensor (d [key ], track_meta = get_track_meta ())
1323- if get_track_meta ():
1324- if not self .lazy_evaluation :
1325- xform_info = self .pop_transform (d [key ], check = False ) if self ._do_transform else {}
1326- self .push_transform (d [key ], extra_info = xform_info )
1327- elif self ._do_transform :
1328- self .push_transform (d [key ], pending = d [key ].pending_operations .pop ()) # type: ignore
1313+ self .push_transform (d [key ], replace = True )
13291314 return d
13301315
13311316 def inverse (self , data : Mapping [Hashable , torch .Tensor ]) -> dict [Hashable , torch .Tensor ]:
@@ -1386,12 +1371,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
13861371 d [key ] = self .flipper (d [key ], randomize = False )
13871372 else :
13881373 d [key ] = convert_to_tensor (d [key ], track_meta = get_track_meta ())
1389- if get_track_meta ():
1390- if not self .lazy_evaluation :
1391- xform = self .pop_transform (d [key ], check = False ) if self ._do_transform else {}
1392- self .push_transform (d [key ], extra_info = xform )
1393- elif self ._do_transform :
1394- self .push_transform (d [key ], pending = d [key ].pending_operations .pop ()) # type: ignore
1374+ self .push_transform (d [key ], replace = True )
13951375 return d
13961376
13971377 def inverse (self , data : Mapping [Hashable , torch .Tensor ]) -> dict [Hashable , torch .Tensor ]:
@@ -1564,12 +1544,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
15641544 )
15651545 else :
15661546 d [key ] = convert_to_tensor (d [key ], track_meta = get_track_meta (), dtype = torch .float32 )
1567- if get_track_meta ():
1568- if not self .lazy_evaluation :
1569- rot_info = self .pop_transform (d [key ], check = False ) if self ._do_transform else {}
1570- self .push_transform (d [key ], extra_info = rot_info )
1571- elif self ._do_transform :
1572- self .push_transform (d [key ], pending = d [key ].pending_operations .pop ()) # type: ignore
1547+ self .push_transform (d [key ], replace = True )
15731548 return d
15741549
15751550 def inverse (self , data : Mapping [Hashable , torch .Tensor ]) -> dict [Hashable , torch .Tensor ]:
@@ -1744,12 +1719,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
17441719 )
17451720 else :
17461721 d [key ] = convert_to_tensor (d [key ], track_meta = get_track_meta (), dtype = torch .float32 )
1747- if get_track_meta ():
1748- if not self .lazy_evaluation :
1749- xform = self .pop_transform (d [key ], check = False ) if self ._do_transform else {}
1750- self .push_transform (d [key ], extra_info = xform )
1751- elif self ._do_transform :
1752- self .push_transform (d [key ], pending = d [key ].pending_operations .pop ()) # type: ignore
1722+ self .push_transform (d [key ], replace = True )
17531723 return d
17541724
17551725 def inverse (self , data : Mapping [Hashable , torch .Tensor ]) -> dict [Hashable , torch .Tensor ]:
0 commit comments