@@ -1224,7 +1224,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
12241224 out : torch .Tensor = xform (img_t .unsqueeze (0 ), transform_t , spatial_size = sp_size ).float ().squeeze (0 )
12251225 out = convert_to_dst_type (out , dst = data , dtype = out .dtype )[0 ]
12261226 if isinstance (data , MetaTensor ):
1227- self .update_meta (out , transform_t )
1227+ out . affine @= self .update_meta (out , transform_t ) # type: ignore
12281228 return out
12291229
12301230
@@ -2756,7 +2756,7 @@ def __call__(
27562756 "do_resampling" : do_resampling ,
27572757 },
27582758 )
2759- out .affine = self .update_meta (out , mat , img .shape [1 :], sp_size ) # type: ignore
2759+ out .affine @ = self .update_meta (out , mat , img .shape [1 :], sp_size ) # type: ignore
27602760 return out
27612761
27622762 def lazy_call (self , img , affine , output_size , mode , padding_mode , do_resampling ) -> torch .Tensor :
@@ -2798,7 +2798,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
27982798 if not isinstance (out , MetaTensor ):
27992799 out = MetaTensor (out )
28002800 out .meta = data .meta # type: ignore
2801- self .update_meta (out , inv_affine , data .shape [1 :], orig_size )
2801+ out . affine @= self .update_meta (out , inv_affine , data .shape [1 :], orig_size )
28022802 return out
28032803
28042804
0 commit comments