Skip to content

Commit

Permalink
Ensure ensemble preds are on the same device (#6368)
Browse files Browse the repository at this point in the history
Fixes # #6361

Ensures ensemble preds are on the same device

---------

Signed-off-by: myron <amyronenko@nvidia.com>
  • Loading branch information
myron authored Apr 15, 2023
1 parent c5b1127 commit 888ad2f
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions monai/apps/auto3dseg/ensemble_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def ensemble_pred(self, preds, sigmoid=False):
a tensor which is the ensembled prediction.
"""

if any(not p.is_cuda for p in preds):
preds = [p.cpu() for p in preds] # ensure CPU if at least one is on CPU

if self.mode == "mean":
prob = MeanEnsemble()(preds)
return prob2class(cast(torch.Tensor, prob), dim=0, keepdim=True, sigmoid=sigmoid)
Expand Down

0 comments on commit 888ad2f

Please sign in to comment.