From 4c2e3413a01b783ccd3be111e30ebb718c99d5e4 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Jan 2023 22:18:33 +0000 Subject: [PATCH] fixes #5509 Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 7 ++++--- monai/transforms/inverse.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 22f95027086..0dd27872a96 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -523,10 +523,11 @@ def ensure_torch_and_prune_meta( By default, a `MetaTensor` is returned. However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned. """ - img = convert_to_tensor(im) # potentially ascontiguousarray + tracking_meta = get_track_meta() and meta is not None + img = convert_to_tensor(im, track_meta=tracking_meta) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` - if not get_track_meta() or meta is None: + if not tracking_meta: return img # remove any superfluous metadata. @@ -540,7 +541,7 @@ def ensure_torch_and_prune_meta( meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta) # return the `MetaTensor` - return MetaTensor(img, meta=meta) + return img.copy_meta_from(meta) def __repr__(self): """ diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 18c22c82fa9..c741786e0b0 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -170,6 +170,8 @@ def track_transform_tensor( if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)): if isinstance(data, Mapping): + if not isinstance(data, dict): + data = dict(data) data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t return data return out_obj # return with data_t as tensor if get_track_meta() is False @@ -202,15 +204,14 @@ def track_transform_tensor( else: out_obj.push_applied_operation(info) if isinstance(data, Mapping): + if not isinstance(data, dict): + data = dict(data) if isinstance(data_t, MetaTensor): data[key] = data_t.copy_meta_from(out_obj) else: - # If this is the first, create list x_k = TraceableTransform.trace_key(key) if x_k not in data: - if not isinstance(data, dict): - data = dict(data) - data[x_k] = [] + data[x_k] = [] # If this is the first, create list data[x_k].append(info) return data return out_obj