Commit 9122739
Removing instantiation in forward (#8580)
### Description

This is a minor fix for the previous PR on this class.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cf5790d commit 9122739

File tree

1 file changed: +1 −5 lines changed


monai/networks/nets/diffusion_model_unet.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -34,7 +34,6 @@
 import math
 from collections.abc import Sequence
 from functools import reduce
-from typing import Optional

 import numpy as np
 import torch
@@ -2016,7 +2015,7 @@ def __init__(

         last_dim_flattened = int(reduce(lambda x, y: x * y, input_shape) * channels[-1])

-        self.out: Optional[nn.Module] = nn.Sequential(
+        self.out: nn.Module = nn.Sequential(
             nn.Linear(last_dim_flattened, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
         )

@@ -2063,9 +2062,6 @@ def forward(
         h = h.reshape(h.shape[0], -1)

         # 5. out
-        self.out = nn.Sequential(
-            nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
-        )
         output: torch.Tensor = self.out(h)

         return output
```
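The diff removes exactly this anti-pattern: `self.out` was being rebuilt inside `forward`, so every call replaced the trained head with a freshly initialized one whose parameters the optimizer never saw. A minimal standalone sketch of the difference (hypothetical `BadHead`/`GoodHead` classes for illustration, not MONAI code):

```python
import torch
import torch.nn as nn


class BadHead(nn.Module):
    """Anti-pattern: the layer is created inside forward()."""

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # A brand-new Linear (new random weights) on every call; it is
        # also never registered as a submodule of this Module.
        out = nn.Linear(h.shape[1], 4)
        return out(h)


class GoodHead(nn.Module):
    """Fix: the layer is created once in __init__()."""

    def __init__(self, in_features: int) -> None:
        super().__init__()
        # Registered as a submodule, so its parameters are visible to
        # optimizers and state_dict(), and stable across calls.
        self.out: nn.Module = nn.Linear(in_features, 4)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.out(h)


h = torch.ones(2, 8)
bad, good = BadHead(), GoodHead(8)

# The bad head exposes no trainable parameters to an optimizer...
print(len(list(bad.parameters())))    # 0
# ...while the good head exposes its weight and bias.
print(len(list(good.parameters())))   # 2
# The good head is also deterministic for identical inputs.
print(torch.equal(good(h), good(h)))  # True
```

Because `BadHead` re-samples its weights per call, identical inputs generally produce different outputs across calls, which is why removing the `forward`-time instantiation is a correctness fix rather than a cosmetic one.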
