Skip to content

Commit 60e4074

Browse files
committed
fixes gpu tests
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 3ef6554 commit 60e4074

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

monai/networks/nets/dints.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,9 @@ def forward(self, x: torch.Tensor, weight: torch.Tensor):
153153
out: weighted average of the operation results.
154154
"""
155155
out = 0.0
156-
for idx, _ in enumerate(self.ops):
157-
out = out + _(x) * weight[idx]
156+
weight = weight.to(x)
157+
for idx, _op in enumerate(self.ops):
158+
out = out + _op(x) * weight[idx]
158159
return out
159160

160161

0 commit comments

Comments
 (0)