5844 improve metatensor slicing (#5845)

Fixes #5844

- improves the error message when slicing data with inconsistent
metadata (see the sketch after the example below)
- fixes an indexing case:
```
x = MetaTensor(np.zeros((10, 3, 4)))
x[slice(1, 0)]  # should return zero length data
```
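
For the first item, a rough sketch of the new behaviour (it mirrors the test added below; the suggested conversions come from the new error message):
```
x = MetaTensor(np.zeros((10, 3, 4)))
x.is_batch = True  # batch flag without correspondingly batched metadata
x[0:8]  # now raises a ValueError suggesting `x.as_tensor()` or `x.array`
```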

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
wyli authored Jan 13, 2023
1 parent 4b464e7 commit f14d50a
Showing 2 changed files with 19 additions and 4 deletions.
16 changes: 12 additions & 4 deletions monai/data/meta_tensor.py
```
@@ -235,11 +235,19 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
                         # respectively. Don't need to do anything with the metadata.
                         if batch_idx not in (slice(None, None, None), Ellipsis, None) and idx == 0:
                             ret_meta = decollate_batch(args[0], detach=False)[batch_idx]
-                            if isinstance(ret_meta, list):  # e.g. batch[0:2], re-collate
-                                ret_meta = list_data_collate(ret_meta)
-                            else:  # e.g. `batch[0]` or `batch[0, 1]`, batch index is an integer
+                            if isinstance(ret_meta, list) and ret_meta:  # e.g. batch[0:2], re-collate
+                                try:
+                                    ret_meta = list_data_collate(ret_meta)
+                                except (TypeError, ValueError, RuntimeError, IndexError) as e:
+                                    raise ValueError(
+                                        "Inconsistent batched metadata dicts when slicing a batch of MetaTensors, "
+                                        "please convert it into a torch Tensor using `x.as_tensor()` or "
+                                        "a numpy array using `x.array`."
+                                    ) from e
+                            elif isinstance(ret_meta, MetaObj):  # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int
                                 ret_meta.is_batch = False
-                            ret.__dict__ = ret_meta.__dict__.copy()
+                            if hasattr(ret_meta, "__dict__"):
+                                ret.__dict__ = ret_meta.__dict__.copy()
                     # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
                     # But we only want to split the batch if the `unbind` is along the 0th
                     # dimension.
```
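
As a usage-level illustration of the code path above, a minimal sketch follows; `MetaTensor`, `list_data_collate` and `decollate_batch` are the names used in the diff, while the surrounding script is only an assumption for demonstration:
```
import numpy as np

from monai.data import MetaTensor
from monai.data.utils import list_data_collate

# a batch assembled by collation carries per-item metadata, so slicing it
# decollates and re-collates that metadata alongside the data
batch = list_data_collate([MetaTensor(np.zeros((3, 4))) for _ in range(10)])
sub = batch[0:2]  # re-collated MetaTensor holding the metadata of two items
item = batch[0]   # integer index: a single item, `is_batch` becomes False

# a tensor merely flagged as a batch has no batched metadata to decollate,
# so slicing it now fails with the clearer ValueError introduced here
x = MetaTensor(np.zeros((10, 3, 4)))
x.is_batch = True
try:
    x[0:8]
except ValueError:
    plain = x.as_tensor()[0:8]  # the workaround suggested by the error message
```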
7 changes: 7 additions & 0 deletions tests/test_meta_tensor.py
```
@@ -406,6 +406,13 @@ def test_indexing(self):
         for _d in d:
             self.check_meta(_d, data)
 
+    def test_slicing(self):
+        x = MetaTensor(np.zeros((10, 3, 4)))
+        self.assertEqual(x[slice(4, 1)].shape[0], 0)
+        x.is_batch = True
+        with self.assertRaises(ValueError):
+            x[slice(0, 8)]
+
     @parameterized.expand(DTYPES)
     @SkipIfBeforePyTorchVersion((1, 8))
     def test_decollate(self, dtype):
```
