MIL Component - MILModel (#3236)
* 3251 Add dependency check in WSIReader (#3312)

* [DLMED] add dep check

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix typo

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <nma@nvidia.com>
Signed-off-by: myron <amyronenko@nvidia.com>

* [DLMED] MILModel PR

Signed-off-by: myron <amyronenko@nvidia.com>

* small updates

Signed-off-by: myron <amyronenko@nvidia.com>

* fix jit issues

Signed-off-by: myron <amyronenko@nvidia.com>

* jit fix attempt

Signed-off-by: myron <amyronenko@nvidia.com>

* removing Enum

Signed-off-by: myron <amyronenko@nvidia.com>

Co-authored-by: Nic Ma <nma@nvidia.com>
Co-authored-by: Behrooz <3968947+drbeh@users.noreply.github.com>
3 people authored Nov 16, 2021
1 parent 9ec1d14 commit 45d4a61
Showing 4 changed files with 322 additions and 0 deletions.
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
@@ -41,6 +41,7 @@
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .milmodel import MILModel
from .netadapter import NetAdapter
from .regressor import Regressor
from .regunet import GlobalNet, LocalNet, RegUNet
229 changes: 229 additions & 0 deletions monai/networks/nets/milmodel.py
@@ -0,0 +1,229 @@
from typing import Dict, Optional, Union, cast

import torch
import torch.nn as nn

from monai.utils.module import optional_import

models, _ = optional_import("torchvision.models")


class MILModel(nn.Module):
"""
Multiple Instance Learning (MIL) model, with a backbone classification model.
Args:
num_classes: number of output classes
mil_mode: MIL algorithm (either mean, max, att, att_trans, att_trans_pyramid)
Defaults to ``att``.
pretrained: init backbone with pretrained weights.
Defaults to ``True``.
backbone: Backbone classifier CNN. (either None, nn.Module that returns features,
or a string name of a torchvision model)
Defaults to ``None``, in which case ResNet50 is used.
backbone_num_features: Number of output features of the backbone CNN
Defaults to ``None`` (necessary only when using a custom backbone)
mil_mode:
"mean" - average features from all instances, equivalent to pure CNN (non MIL)
"max - retain only the instance with the max probability for loss calculation
"att" - attention based MIL https://arxiv.org/abs/1802.04712
"att_trans" - transformer MIL https://arxiv.org/abs/2111.01556
"att_trans_pyramid" - transformer pyramid MIL https://arxiv.org/abs/2111.01556
"""

def __init__(
self,
num_classes: int,
mil_mode: str = "att",
pretrained: bool = True,
backbone: Optional[Union[str, nn.Module]] = None,
backbone_num_features: Optional[int] = None,
trans_blocks: int = 4,
trans_dropout: float = 0.0,
) -> None:

super().__init__()

if num_classes <= 0:
raise ValueError("Number of classes must be positive: " + str(num_classes))

if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
raise ValueError("Unsupported mil_mode: " + str(mil_mode))

self.mil_mode = mil_mode.lower()
print("MILModel with mode", mil_mode, "num_classes", num_classes)
self.attention = nn.Sequential()
self.transformer = None # type: Optional[nn.Module]

if backbone is None:

net = models.resnet50(pretrained=pretrained)
nfc = net.fc.in_features # save the number of final features
net.fc = torch.nn.Identity() # remove final linear layer

self.extra_outputs = {} # type: Dict[str, torch.Tensor]

            if self.mil_mode == "att_trans_pyramid":
# register hooks to capture outputs of intermediate layers
def forward_hook(layer_name):
def hook(module, input, output):
self.extra_outputs[layer_name] = output

return hook

net.layer1.register_forward_hook(forward_hook("layer1"))
net.layer2.register_forward_hook(forward_hook("layer2"))
net.layer3.register_forward_hook(forward_hook("layer3"))
net.layer4.register_forward_hook(forward_hook("layer4"))

elif isinstance(backbone, str):

# assume torchvision model string is provided
torch_model = getattr(models, backbone, None)
if torch_model is None:
raise ValueError("Unknown torch vision model" + str(backbone))
net = torch_model(pretrained=pretrained)

if getattr(net, "fc", None) is not None:
nfc = net.fc.in_features # save the number of final features
net.fc = torch.nn.Identity() # remove final linear layer
else:
                raise ValueError(
                    "Unable to detect FC layer for the torchvision model " + str(backbone)
                    + ". Please initialize the backbone model manually."
                )

elif isinstance(backbone, nn.Module):
# use a custom backbone
net = backbone
nfc = backbone_num_features

if backbone_num_features is None:
raise ValueError("Number of endencoder features must be provided for a custom backbone model")

else:
raise ValueError("Unsupported backbone")

        if backbone is not None and self.mil_mode not in ["mean", "max", "att", "att_trans"]:
            raise ValueError("Custom backbone is not supported for the mode: " + str(mil_mode))

if self.mil_mode in ["mean", "max"]:
pass
elif self.mil_mode == "att":
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

elif self.mil_mode == "att_trans":
transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)
self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks)
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

elif self.mil_mode == "att_trans_pyramid":

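            # ResNet50 intermediate feature channels: layer1=256, layer2=512, layer3=1024, layer4=2048.
            # Each pyramid stage keeps d_model=256 and fuses it with the next layer's pooled features,
            # hence the projection input sizes 256+512=768 and 256+1024=1280, and d_model=2304=256+2048
            # for the final stage.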
transformer_list = nn.ModuleList(
[
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=trans_blocks
),
nn.Sequential(
nn.Linear(768, 256),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
),
nn.Sequential(
nn.Linear(1280, 256),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout),
num_layers=trans_blocks,
),
]
)
self.transformer = transformer_list
nfc = nfc + 256
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

else:
raise ValueError("Unsupported mil_mode: " + str(mil_mode))

self.myfc = nn.Linear(nfc, num_classes)
self.net = net

def calc_head(self, x: torch.Tensor) -> torch.Tensor:
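        # x: (batch, num_instances, num_features); pooled to bag-level output of shape (batch, num_classes)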

sh = x.shape

if self.mil_mode == "mean":
x = self.myfc(x)
x = torch.mean(x, dim=1)

elif self.mil_mode == "max":
x = self.myfc(x)
x, _ = torch.max(x, dim=1)

elif self.mil_mode == "att":

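            # attention pooling: score each instance, softmax over the instance dim, then weighted-sum features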
a = self.attention(x)
a = torch.softmax(a, dim=1)
x = torch.sum(x * a, dim=1)

x = self.myfc(x)

elif self.mil_mode == "att_trans" and self.transformer is not None:

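            # nn.TransformerEncoder expects (seq_len, batch, features), so instances act as the sequence dim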
x = x.permute(1, 0, 2)
x = self.transformer(x)
x = x.permute(1, 0, 2)

a = self.attention(x)
a = torch.softmax(a, dim=1)
x = torch.sum(x * a, dim=1)

x = self.myfc(x)

elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:

l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)

transformer_list = cast(nn.ModuleList, self.transformer)

x = transformer_list[0](l1)
x = transformer_list[1](torch.cat((x, l2), dim=2))
x = transformer_list[2](torch.cat((x, l3), dim=2))
x = transformer_list[3](torch.cat((x, l4), dim=2))

x = x.permute(1, 0, 2)

a = self.attention(x)
a = torch.softmax(a, dim=1)
x = torch.sum(x * a, dim=1)

x = self.myfc(x)

else:
raise ValueError("Wrong model mode" + str(self.mil_mode))

return x

def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:

sh = x.shape
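        # x: (batch, num_instances, channels, height, width); merge the first two dims
        # so the backbone processes every instance as a separate image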
x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])

x = self.net(x)
x = x.reshape(sh[0], sh[1], -1)

if not no_head:
x = self.calc_head(x)

return x
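
For reference, a minimal usage sketch of the new model (assumes torchvision is available; random data is used purely for illustration):

    import torch

    from monai.networks.nets import MILModel

    model = MILModel(num_classes=2, mil_mode="att", pretrained=False)
    # one bag of 4 instances, e.g. 512x512 RGB patches extracted from a whole-slide image
    images = torch.randn(1, 4, 3, 512, 512)
    logits = model(images)  # shape (1, 2): one prediction per bag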
1 change: 1 addition & 0 deletions tests/min_tests.py
@@ -104,6 +104,7 @@ def run_testsuit():
"test_load_imaged",
"test_load_spacing_orientation",
"test_mednistdataset",
"test_milmodel",
"test_mlp",
"test_nifti_header_revise",
"test_nifti_rw",
91 changes: 91 additions & 0 deletions tests/test_milmodel.py
@@ -0,0 +1,91 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import MILModel
from monai.utils.module import optional_import
from tests.utils import test_script_save

models, _ = optional_import("torchvision.models")

device = "cuda" if torch.cuda.is_available() else "cpu"


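# each case: (MILModel kwargs, input shape (batch, num_instances, channels, H, W), expected output shape)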
TEST_CASE_MILMODEL = []
for num_classes in [1, 5]:
for mil_mode in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
test_case = [
{"num_classes": num_classes, "mil_mode": mil_mode, "pretrained": False},
(1, 2, 3, 512, 512),
(1, num_classes),
]
TEST_CASE_MILMODEL.append(test_case)


for trans_blocks in [1, 3]:
test_case = [
{"num_classes": 5, "pretrained": False, "trans_blocks": trans_blocks, "trans_dropout": 0.5},
(1, 2, 3, 512, 512),
(1, 5),
]
TEST_CASE_MILMODEL.append(test_case)

# torchvision backbone
TEST_CASE_MILMODEL.append(
[{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)]
)
TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)])

# custom backbone
backbone = models.densenet121(pretrained=False)
backbone_nfeatures = backbone.classifier.in_features
backbone.classifier = torch.nn.Identity()
TEST_CASE_MILMODEL.append(
[
{"num_classes": 5, "backbone": backbone, "backbone_num_features": backbone_nfeatures, "pretrained": False},
(2, 2, 3, 512, 512),
(2, 5),
]
)


class TestMilModel(unittest.TestCase):
@parameterized.expand(TEST_CASE_MILMODEL)
def test_shape(self, input_param, input_shape, expected_shape):
net = MILModel(**input_param).to(device)
with eval_mode(net):
result = net(torch.randn(input_shape, dtype=torch.float).to(device))
self.assertEqual(result.shape, expected_shape)

def test_ill_args(self):
with self.assertRaises(ValueError):
MILModel(
num_classes=5,
pretrained=False,
backbone="resnet50",
backbone_num_features=2048,
mil_mode="att_trans_pyramid",
)

def test_script(self):
input_param, input_shape, expected_shape = TEST_CASE_MILMODEL[0]
net = MILModel(**input_param)
test_data = torch.randn(input_shape, dtype=torch.float)
test_script_save(net, test_data)


if __name__ == "__main__":
unittest.main()
