Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Jan 30, 2023
1 parent 696e411 commit 4c2e341
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
7 changes: 4 additions & 3 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,11 @@ def ensure_torch_and_prune_meta(
By default, a `MetaTensor` is returned.
However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned.
"""
img = convert_to_tensor(im) # potentially ascontiguousarray
tracking_meta = get_track_meta() and meta is not None
img = convert_to_tensor(im, track_meta=tracking_meta) # potentially ascontiguousarray

# if not tracking metadata, return `torch.Tensor`
if not get_track_meta() or meta is None:
if not tracking_meta:
return img

# remove any superfluous metadata.
Expand All @@ -540,7 +541,7 @@ def ensure_torch_and_prune_meta(
meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)

# return the `MetaTensor`
return MetaTensor(img, meta=meta)
return img.copy_meta_from(meta)

def __repr__(self):
"""
Expand Down
9 changes: 5 additions & 4 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def track_transform_tensor(

if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
if isinstance(data, Mapping):
if not isinstance(data, dict):
data = dict(data)
data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t
return data
return out_obj # return with data_t as tensor if get_track_meta() is False
Expand Down Expand Up @@ -202,15 +204,14 @@ def track_transform_tensor(
else:
out_obj.push_applied_operation(info)
if isinstance(data, Mapping):
if not isinstance(data, dict):
data = dict(data)
if isinstance(data_t, MetaTensor):
data[key] = data_t.copy_meta_from(out_obj)
else:
# If this is the first, create list
x_k = TraceableTransform.trace_key(key)
if x_k not in data:
if not isinstance(data, dict):
data = dict(data)
data[x_k] = []
data[x_k] = [] # If this is the first, create list
data[x_k].append(info)
return data
return out_obj
Expand Down

0 comments on commit 4c2e341

Please sign in to comment.