Skip to content
Merged
36 changes: 26 additions & 10 deletions monai/networks/nets/vnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args
from monai.utils import deprecated_arg

__all__ = ["VNet"]

Expand Down Expand Up @@ -133,7 +134,7 @@ def __init__(
out_channels: int,
nconvs: int,
act: tuple[str, dict] | str,
dropout_prob: float | None = None,
dropout_prob: tuple[float | None, float] = (None, 0.5),
dropout_dim: int = 3,
):
super().__init__()
Expand All @@ -144,8 +145,8 @@ def __init__(

self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2)
self.bn1 = norm_type(out_channels // 2)
self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None
self.dropout2 = dropout_type(0.5)
self.dropout = dropout_type(dropout_prob[0]) if dropout_prob[0] is not None else None
self.dropout2 = dropout_type(dropout_prob[1])
self.act_function1 = get_acti_layer(act, out_channels // 2)
self.act_function2 = get_acti_layer(act, out_channels)
self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act)
Expand Down Expand Up @@ -206,8 +207,9 @@ class VNet(nn.Module):
The value should meet the condition that ``16 % in_channels == 0``.
out_channels: number of output channels for the network. Defaults to 1.
act: activation type in the network. Defaults to ``("elu", {"inplace": True})``.
dropout_prob: dropout ratio. Defaults to 0.5.
dropout_dim: determine the dimensions of dropout. Defaults to 3.
dropout_prob_down: dropout ratio for DownTransition blocks. Defaults to 0.5.
dropout_prob_up: dropout ratio for UpTransition blocks. Defaults to (0.5, 0.5).
dropout_dim: determine the dimensions of dropout. Defaults to (0.5, 0.5).

- ``dropout_dim = 1``, randomly zeroes some of the elements for each channel.
- ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map).
Expand All @@ -216,15 +218,29 @@ class VNet(nn.Module):
According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
if a conv layer is directly followed by a batch norm layer, bias should be False.

.. deprecated:: 1.2
``dropout_prob`` is deprecated in favor of ``dropout_prob_down`` and ``dropout_prob_up``.

"""

@deprecated_arg(
name="dropout_prob",
since="1.2",
new_name="dropout_prob_down",
msg_suffix="please use `dropout_prob_down` instead.",
)
@deprecated_arg(
name="dropout_prob", since="1.2", new_name="dropout_prob_up", msg_suffix="please use `dropout_prob_up` instead."
)
def __init__(
self,
spatial_dims: int = 3,
in_channels: int = 1,
out_channels: int = 1,
act: tuple[str, dict] | str = ("elu", {"inplace": True}),
dropout_prob: float = 0.5,
dropout_prob: float | None = 0.5, # deprecated
dropout_prob_down: float | None = 0.5,
dropout_prob_up: tuple[float | None, float] = (0.5, 0.5),
dropout_dim: int = 3,
bias: bool = False,
):
Expand All @@ -236,10 +252,10 @@ def __init__(
self.in_tr = InputTransition(spatial_dims, in_channels, 16, act, bias=bias)
self.down_tr32 = DownTransition(spatial_dims, 16, 1, act, bias=bias)
self.down_tr64 = DownTransition(spatial_dims, 32, 2, act, bias=bias)
self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob, bias=bias)
self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob, bias=bias)
self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob)
self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob)
self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob_down, bias=bias)
self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob_down, bias=bias)
self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob_up)
self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob_up)
self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act)
self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act)
self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act, bias=bias)
Expand Down