diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 5604d9de1e..113aa505e9 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -33,6 +33,7 @@ import math from collections.abc import Sequence +from typing import Optional import torch from torch import nn @@ -2006,7 +2007,7 @@ def __init__( self.down_blocks.append(down_block) - self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) + self.out: Optional[nn.Module] = None def forward( self, @@ -2049,6 +2050,12 @@ def forward( h, _ = downsample_block(hidden_states=h, temb=emb, context=context) h = h.reshape(h.shape[0], -1) + + # 5. out + if self.out is None: + self.out = nn.Sequential( + nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels) + ) output: torch.Tensor = self.out(h) return output