[DLMED] MILmodel draft PR
Signed-off-by: amyronenko <amyronenko@nvidia.com>
Signed-off-by: myron <amyronenko@nvidia.com>
am authored and myron committed Nov 2, 2021
1 parent 3960a51 commit af988b8
Showing 1 changed file with 184 additions and 0 deletions.
184 changes: 184 additions & 0 deletions monai/networks/nets/milmodel.py
@@ -0,0 +1,184 @@
import torch
import torch.nn as nn
import torchvision.models as models


class MILModel(nn.Module):
"""
A wrapper around backbone classification model suitable for MIL
Args:
num_classes: number of output classes
mil_mode: MIL variant (supported max, mean, att, att_trans, att_trans_pyramid
pretrained: init backbone with pretrained weights. Defaults to True.
backbone: Backbone classifier CNN. Defaults to None, it which case ResNet50 will be used.
backbone_nfeatures: Number of output featues of the backbone CNN (necessary only when using custom backbone)
mil_mode:
mean - average features from all instances, equivalent to pure CNN (non MIL)
max - retain only the instance with the max probability for loss calculation
att - attention based MIL https://arxiv.org/abs/1802.04712
att_trans - transformer MIL
att_trans_pyramid - transformer pyramid MIL
"""

def __init__(self, num_classes, mil_mode="att", pretrained=True, backbone=None, backbone_nfeatures=None):

super().__init__()

self.mil_mode = mil_mode
print("MILModel with mode", mil_mode, "num_classes", num_classes)
n_trans = 4
trans_dropout = 0.0
self.attention = None

        if backbone is None:
            # default backbone: torchvision ResNet50
            net = models.resnet50(pretrained=pretrained)
            nfc = net.fc.in_features  # save the number of final features
            net.fc = torch.nn.Identity()  # remove the final linear layer

            self.extra_outputs = {}
            if mil_mode == "att_trans_pyramid":
                # register forward hooks to capture the outputs of intermediate layers
                def forward_hook(layer_name):
                    def hook(module, input, output):
                        self.extra_outputs[layer_name] = output

                    return hook

                # keep the hook handles so they can be removed later if needed
                self.fhooks = [
                    net.layer1.register_forward_hook(forward_hook("layer1")),
                    net.layer2.register_forward_hook(forward_hook("layer2")),
                    net.layer3.register_forward_hook(forward_hook("layer3")),
                    net.layer4.register_forward_hook(forward_hook("layer4")),
                ]
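                # after a forward pass, self.extra_outputs maps each layer name to its
                # feature map; the ResNet50 stage widths are 256, 512, 1024 and 2048 channels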

        else:
            # use a custom backbone (untested)
            net = backbone
            nfc = backbone_nfeatures

            if backbone_nfeatures is None:
                raise ValueError("Number of encoder features must be provided for a custom backbone model")
            # "att_trans_pyramid" relies on the ResNet50 layer hooks above,
            # so it is not available with a custom backbone
            if mil_mode not in ["mean", "max", "att", "att_trans"]:
                raise ValueError("Custom backbone is not supported for the mode: " + str(mil_mode))

if self.mil_mode in ["mean", "max"]:
pass

elif self.mil_mode == "att":
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

elif self.mil_mode == "att_trans":
transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)
self.transformer = nn.TransformerEncoder(transformer, num_layers=n_trans)
self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

elif self.mil_mode == "att_trans_pyramid":

self.transformer = nn.ModuleList(
[
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=n_trans
),
nn.Sequential(
nn.Linear(768, 256),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=n_trans
),
),
nn.Sequential(
nn.Linear(1280, 256),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=n_trans
),
),
nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout), num_layers=n_trans
),
]
)
            nfc = nfc + 256  # the pyramid head operates on 2048 + 256 = 2304 features
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

else:
raise ValueError("Unsupported mil_mode: " + str(mil_mode))

self.myfc = nn.Linear(nfc, num_classes)
self.net = net

    def calc_head(self, x, sh):
        # restore the bag structure: (batch * num_instances, features) -> (batch, num_instances, features)
        x = x.reshape(sh[0], sh[1], -1)

if self.mil_mode == "mean":
x = self.myfc(x)
x = torch.mean(x, dim=1)

elif self.mil_mode == "max":
x = self.myfc(x)
x, _ = torch.max(x, dim=1)

elif self.mil_mode == "att":

a = self.attention(x)
a = torch.softmax(a, dim=1)
x = torch.sum(x * a, dim=1)

x = self.myfc(x)

elif self.mil_mode == "att_trans":

x = x.permute(1, 0, 2)
x = self.transformer(x)
x = x.permute(1, 0, 2)

a = self.attention(x)
a = torch.softmax(a, dim=1)
x = torch.sum(x * a, dim=1)

x = self.myfc(x)

elif self.mil_mode == "att_trans_pyramid":

l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)

x = self.transformer[0](l1)
x = self.transformer[1](torch.cat((x, l2), dim=2))
x = self.transformer[2](torch.cat((x, l3), dim=2))
x = self.transformer[3](torch.cat((x, l4), dim=2))

x = x.permute(1, 0, 2)

a = self.attention(x)
a = torch.softmax(a, dim=1)
x = torch.sum(x * a, dim=1)

x = self.myfc(x)

else:
raise ValueError("wrong model mode" + str(self.mil_mode))

return x

    def forward(self, x, no_head=False):
        # merge the batch and instance dimensions so the backbone sees a plain
        # batch of 2D images: (b, n, c, h, w) -> (b * n, c, h, w)
        sh = x.shape
        x = x.reshape(-1, sh[2], sh[3], sh[4])

x = self.net(x)

if not no_head:
x = self.calc_head(x, sh)

return x
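
For reference, a minimal usage sketch (not part of the diff; the import path follows the file location above, and the bag size, image size, and class count are arbitrary):

    import torch
    from monai.networks.nets.milmodel import MILModel

    # 4 bags of 7 instances each, every instance a 3x224x224 image
    model = MILModel(num_classes=2, mil_mode="att", pretrained=False)
    images = torch.rand(4, 7, 3, 224, 224)  # (batch, num_instances, channels, height, width)
    logits = model(images)  # one prediction per bag
    print(logits.shape)  # torch.Size([4, 2])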
