Fix: Small logic mistake in the AsDiscrete.__call__ method (Project-MONAI#7984)

Hi MONAI Team! 
Thank you very much for this super nice framework, really appreciate it!
Just found a small logic mistake in one of the transform classes. To
reproduce:
```python
import torch
from monai.transforms.post.array import AsDiscrete

transform = AsDiscrete(argmax=True)
prediction = torch.rand(2, 3, 3)

transform(prediction, argmax=False)
# argmax is still applied: `if argmax or self.argmax` ignores the explicit False
```

### Description

Proposed fix: `argmax` is explicitly checked for `None` in the `__call__`
method, so an explicit `argmax=False` passed at call time is no longer ignored.
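
With the fix in place, the reproduction snippet above should behave as expected. A short sketch of the intended behavior (this is an illustration of the expected outcome, not part of the patch itself):

```python
import torch
from monai.transforms.post.array import AsDiscrete

transform = AsDiscrete(argmax=True)
prediction = torch.rand(2, 3, 3)

# argmax=False is now honored: the output should keep the input shape (2, 3, 3)
out = transform(prediction, argmax=False)

# omitting the argument still falls back to the constructor default (argmax=True),
# reducing over the channel dimension to shape (1, 3, 3)
out_default = transform(prediction)
```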

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: David Carreto Fidalgo <davidc.fidalgo@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
dcfidalgo and KumoLiu authored Aug 3, 2024
1 parent ae5a04d commit 56ee32e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion monai/transforms/post/array.py
```diff
@@ -211,7 +211,8 @@ def __call__(
             raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
         img = convert_to_tensor(img, track_meta=get_track_meta())
         img_t, *_ = convert_data_type(img, torch.Tensor)
-        if argmax or self.argmax:
+        argmax = self.argmax if argmax is None else argmax
+        if argmax:
             img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))
 
         to_onehot = self.to_onehot if to_onehot is None else to_onehot
```
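
Using `None` as the call-time default lets the method distinguish "argument not provided" from an explicit `argmax=False`, matching the pattern already used for `to_onehot` on the last line of the hunk.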
