with eval_mode(net): #1384

Merged Dec 29, 2020 (5 commits)
7 changes: 3 additions & 4 deletions monai/engines/evaluator.py
@@ -18,6 +18,7 @@
from monai.engines.utils import default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.networks.utils import eval_mode
from monai.transforms import Transform
from monai.utils import ensure_tuple, exact_version, optional_import

@@ -190,8 +191,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
inputs, targets, args, kwargs = batch

# execute forward computation
self.network.eval()
with torch.no_grad():
with eval_mode(self.network):
if self.amp:
with torch.cuda.amp.autocast():
predictions = self.inferer(inputs, self.network, *args, **kwargs)
@@ -298,8 +298,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict
# execute forward computation
predictions = {Keys.IMAGE: inputs, Keys.LABEL: targets}
for idx, network in enumerate(self.networks):
network.eval()
with torch.no_grad():
with eval_mode(network):
if self.amp:
with torch.cuda.amp.autocast():
predictions.update({self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)})
35 changes: 35 additions & 0 deletions monai/networks/utils.py
@@ -13,6 +13,7 @@
"""

import warnings
from contextlib import contextmanager
from typing import Any, Callable, Optional, Sequence, cast

import torch
@@ -29,6 +30,7 @@
"normal_init",
"icnr_init",
"pixelshuffle",
"eval_mode",
]


@@ -241,3 +243,36 @@ def pixelshuffle(x: torch.Tensor, dimensions: int, scale_factor: int) -> torch.T
x = x.reshape(batch_size, org_channels, *([factor] * dim + input_size[2:]))
x = x.permute(permute_indices).reshape(output_size)
return x


@contextmanager
def eval_mode(*nets: nn.Module):
"""
Set network(s) to eval mode, then restore the original training state on exit.

Args:
nets: Input network(s)

Examples

.. code-block:: python

t=torch.rand(1,1,16,16)
p=torch.nn.Conv2d(1,1,3)
print(p.training) # True
with eval_mode(p):
print(p.training) # False
print(p(t).sum().backward())  # correctly raises a RuntimeError, as gradients are not tracked
"""

# Get original state of network(s)
training = [n for n in nets if n.training]

try:
# set to eval mode
with torch.no_grad():
yield [n.eval() for n in nets]
finally:
# Return networks that were originally in training mode back to training
for n in training:
n.train()
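
For readers skimming the diff, here is a minimal sketch of how the new context manager replaces the old eval()/no_grad() idiom; the network and tensor names below are illustrative only, not part of the change set.

import torch
import torch.nn as nn

from monai.networks.utils import eval_mode

net = nn.Conv2d(1, 1, 3)  # a toy network; freshly built modules start in training mode
x = torch.rand(1, 1, 16, 16)

# Previous idiom used throughout the codebase: switch to eval mode and
# disable gradient tracking by hand, then remember to restore training mode.
net.eval()
with torch.no_grad():
    y = net(x)
net.train()

# New idiom: eval_mode() performs eval(), wraps the block in no_grad(),
# and restores the original training state when the block exits.
with eval_mode(net):
    y = net(x)

print(net.training)  # True again after the block
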
24 changes: 9 additions & 15 deletions tests/test_ahnet.py
@@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks import FCN, MCFCN
from monai.networks.nets import AHNet
from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save
@@ -127,8 +128,7 @@ class TestFCN(unittest.TestCase):
@skip_if_quick
def test_fcn_shape(self, input_param, input_shape, expected_shape):
net = FCN(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@@ -138,8 +138,7 @@ class TestFCNWithPretrain(unittest.TestCase):
@skip_if_quick
def test_fcn_shape(self, input_param, input_shape, expected_shape):
net = test_pretrained_networks(FCN, input_param, device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@@ -148,8 +147,7 @@ class TestMCFCN(unittest.TestCase):
@parameterized.expand([TEST_CASE_MCFCN_1, TEST_CASE_MCFCN_2, TEST_CASE_MCFCN_3])
def test_mcfcn_shape(self, input_param, input_shape, expected_shape):
net = MCFCN(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@@ -158,8 +156,7 @@ class TestMCFCNWithPretrain(unittest.TestCase):
@parameterized.expand([TEST_CASE_MCFCN_WITH_PRETRAIN_1, TEST_CASE_MCFCN_WITH_PRETRAIN_2])
def test_mcfcn_shape(self, input_param, input_shape, expected_shape):
net = test_pretrained_networks(MCFCN, input_param, device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@@ -174,8 +171,7 @@ class TestAHNET(unittest.TestCase):
)
def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape):
net = AHNet(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@@ -189,8 +185,7 @@ def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape):
@skip_if_quick
def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape):
net = AHNet(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@@ -213,8 +208,7 @@ def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_p
net = AHNet(**input_param).to(device)
net2d = FCN(**fcn_input_param).to(device)
net.copy_from(net2d)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@@ -230,7 +224,7 @@ def test_initialize_pretrained(self):
progress=True,
).to(device)
input_data = torch.randn(2, 2, 32, 32, 64).to(device)
with torch.no_grad():
with eval_mode(net):
result = net.forward(input_data)
self.assertEqual(result.shape, (2, 3, 32, 32, 64))

4 changes: 2 additions & 2 deletions tests/test_autoencoder.py
@@ -3,6 +3,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.layers import Act
from monai.networks.nets import AutoEncoder
from tests.utils import test_script_save
@@ -75,8 +76,7 @@ class TestAutoEncoder(unittest.TestCase):
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = AutoEncoder(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

4 changes: 2 additions & 2 deletions tests/test_basic_unet.py
@@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import BasicUNet
from tests.utils import test_script_save

@@ -95,8 +96,7 @@ def test_shape(self, input_param, input_shape, expected_shape):
device = "cuda" if torch.cuda.is_available() else "cpu"
print(input_param)
net = BasicUNet(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

4 changes: 2 additions & 2 deletions tests/test_channel_pad.py
@@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.layers import ChannelPad

TEST_CASES_3D = []
@@ -34,8 +35,7 @@ class TestChannelPad(unittest.TestCase):
@parameterized.expand(TEST_CASES_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = ChannelPad(**input_param)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net(torch.randn(input_shape))
self.assertEqual(list(result.shape), list(expected_shape))

4 changes: 2 additions & 2 deletions tests/test_copy_itemsd.py
@@ -15,6 +15,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.transforms import CopyItemsd
from monai.utils import ensure_tuple

@@ -61,8 +62,7 @@ def test_array_values(self):
def test_graph_tensor_values(self):
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
net = torch.nn.PReLU().to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
pred = net(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device))
input_data = {"pred": pred, "seg": torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)}
result = CopyItemsd(keys="pred", times=1, names="pred_1")(input_data)
7 changes: 3 additions & 4 deletions tests/test_densenet.py
@@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import densenet121, densenet169, densenet201, densenet264
from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save

@@ -66,8 +67,7 @@ class TestPretrainedDENSENET(unittest.TestCase):
@skip_if_quick
def test_121_3d_shape_pretrain(self, model, input_param, input_shape, expected_shape):
net = test_pretrained_networks(model, input_param, device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@@ -76,8 +76,7 @@ class TestDENSENET(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_densenet_shape(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

4 changes: 2 additions & 2 deletions tests/test_discriminator.py
@@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import Discriminator
from tests.utils import test_script_save

@@ -42,8 +43,7 @@ class TestDiscriminator(unittest.TestCase):
@parameterized.expand(CASES)
def test_shape(self, input_param, input_data, expected_shape):
net = Discriminator(**input_param)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(input_data)
self.assertEqual(result.shape, expected_shape)

4 changes: 2 additions & 2 deletions tests/test_downsample_block.py
@@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks import MaxAvgPool

TEST_CASES = [
@@ -41,8 +42,7 @@ class TestMaxAvgPool(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = MaxAvgPool(**input_param)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)

4 changes: 2 additions & 2 deletions tests/test_dynunet.py
@@ -15,6 +15,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import DynUNet
from tests.utils import test_script_save

@@ -107,8 +108,7 @@ class TestDynUNet(unittest.TestCase):
@parameterized.expand(TEST_CASE_DYNUNET_2D + TEST_CASE_DYNUNET_3D)
def test_shape(self, input_param, input_shape, expected_shape):
net = DynUNet(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net(torch.randn(input_shape).to(device))
self.assertEqual(result[0].shape, expected_shape)

7 changes: 3 additions & 4 deletions tests/test_dynunet_block.py
@@ -14,6 +14,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding
from tests.utils import test_script_save

@@ -70,8 +71,7 @@ class TestResBasicBlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_RES_BASIC_BLOCK)
def test_shape(self, input_param, input_shape, expected_shape):
for net in [UnetResBlock(**input_param), UnetBasicBlock(**input_param)]:
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)

@@ -94,8 +94,7 @@ class TestUpBlock(unittest.TestCase):
@parameterized.expand(TEST_UP_BLOCK)
def test_shape(self, input_param, input_shape, expected_shape, skip_shape):
net = UnetUpBlock(**input_param)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net(torch.randn(input_shape), torch.randn(skip_shape))
self.assertEqual(result.shape, expected_shape)

31 changes: 31 additions & 0 deletions tests/test_eval_mode.py
@@ -0,0 +1,31 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from monai.networks.utils import eval_mode


class TestEvalMode(unittest.TestCase):
def test_eval_mode(self):
t = torch.rand(1, 1, 4, 4)
p = torch.nn.Conv2d(1, 1, 3)
self.assertTrue(p.training) # True
with eval_mode(p):
self.assertFalse(p.training) # False
with self.assertRaises(RuntimeError):
p(t).sum().backward()


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions tests/test_fullyconnectednet.py
@@ -3,6 +3,7 @@
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import FullyConnectedNet, VarFullyConnectedNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -45,8 +46,7 @@ def test_fc_shape(self, dropout):
@parameterized.expand(VFC_CASES)
def test_vfc_shape(self, input_param, input_shape, expected_shape):
net = VarFullyConnectedNet(**input_param).to(device)
net.eval()
with torch.no_grad():
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))[0]
self.assertEqual(result.shape, expected_shape)
