Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vnet): convert dropout_prob to a tuple #6768

Merged
merged 13 commits into from
Jul 27, 2023
36 changes: 26 additions & 10 deletions monai/networks/nets/vnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args
from monai.utils import deprecated_arg

__all__ = ["VNet"]

Expand Down Expand Up @@ -133,7 +134,7 @@ def __init__(
out_channels: int,
nconvs: int,
act: tuple[str, dict] | str,
dropout_prob: float | None = None,
dropout_prob: tuple[float | None, float] = (None, 0.5),
dropout_dim: int = 3,
):
super().__init__()
Expand All @@ -144,8 +145,8 @@ def __init__(

self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2)
self.bn1 = norm_type(out_channels // 2)
self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None
self.dropout2 = dropout_type(0.5)
self.dropout = dropout_type(dropout_prob[0]) if dropout_prob[0] is not None else None
self.dropout2 = dropout_type(dropout_prob[1])
self.act_function1 = get_acti_layer(act, out_channels // 2)
self.act_function2 = get_acti_layer(act, out_channels)
self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act)
Expand Down Expand Up @@ -206,8 +207,9 @@ class VNet(nn.Module):
The value should meet the condition that ``16 % in_channels == 0``.
out_channels: number of output channels for the network. Defaults to 1.
act: activation type in the network. Defaults to ``("elu", {"inplace": True})``.
dropout_prob: dropout ratio. Defaults to 0.5.
dropout_dim: determine the dimensions of dropout. Defaults to 3.
dropout_prob_down: dropout ratio for DownTransition blocks. Defaults to 0.5.
dropout_prob_up: dropout ratio for UpTransition blocks. Defaults to (0.5, 0.5).
dropout_dim: determine the dimensions of dropout. Defaults to (0.5, 0.5).

- ``dropout_dim = 1``, randomly zeroes some of the elements for each channel.
- ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map).
Expand All @@ -216,15 +218,29 @@ class VNet(nn.Module):
According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
if a conv layer is directly followed by a batch norm layer, bias should be False.

.. deprecated:: 1.2
``dropout_prob`` is deprecated in favor of ``dropout_prob_down`` and ``dropout_prob_up``.

"""

@deprecated_arg(
name="dropout_prob",
since="1.2",
new_name="dropout_prob_down",
msg_suffix="please use `dropout_prob_down` instead.",
)
@deprecated_arg(
name="dropout_prob", since="1.2", new_name="dropout_prob_up", msg_suffix="please use `dropout_prob_up` instead."
)
def __init__(
self,
spatial_dims: int = 3,
in_channels: int = 1,
out_channels: int = 1,
act: tuple[str, dict] | str = ("elu", {"inplace": True}),
dropout_prob: float = 0.5,
SauravMaheshkar marked this conversation as resolved.
Show resolved Hide resolved
dropout_prob: float | None = 0.5, # deprecated
dropout_prob_down: float | None = 0.5,
dropout_prob_up: tuple[float | None, float] = (0.5, 0.5),
dropout_dim: int = 3,
bias: bool = False,
):
Expand All @@ -236,10 +252,10 @@ def __init__(
self.in_tr = InputTransition(spatial_dims, in_channels, 16, act, bias=bias)
self.down_tr32 = DownTransition(spatial_dims, 16, 1, act, bias=bias)
self.down_tr64 = DownTransition(spatial_dims, 32, 2, act, bias=bias)
self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob, bias=bias)
self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob, bias=bias)
self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob)
self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob)
self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob_down, bias=bias)
self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob_down, bias=bias)
self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob_up)
self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob_up)
self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act)
self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act)
self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act, bias=bias)
Expand Down
Loading