Signed-off-by: amyronenko <amyronenko@nvidia.com>
Signed-off-by: myron <amyronenko@nvidia.com>
Showing 1 changed file with 184 additions and 0 deletions.
@@ -0,0 +1,184 @@
import torch
import torch.nn as nn
import torchvision.models as models
class MILModel(nn.Module):
    """
    A wrapper around a backbone classification model, suitable for MIL (multiple instance learning).

    Args:
        num_classes: number of output classes.
        mil_mode: MIL aggregation variant. One of:
            "mean" - average features from all instances, equivalent to a pure CNN (non-MIL).
            "max" - retain only the instance with the max probability for loss calculation.
            "att" - attention-based MIL, https://arxiv.org/abs/1802.04712.
            "att_trans" - transformer MIL.
            "att_trans_pyramid" - transformer pyramid MIL.
        pretrained: init the backbone with pretrained weights. Defaults to True.
        backbone: backbone classifier CNN. Defaults to None, in which case ResNet50 is used.
        backbone_nfeatures: number of output features of the backbone CNN (required only for a custom backbone).
    """
    def __init__(self, num_classes, mil_mode="att", pretrained=True, backbone=None, backbone_nfeatures=None):

        super().__init__()

        self.mil_mode = mil_mode
        print("MILModel with mode", mil_mode, "num_classes", num_classes)

        n_trans = 4  # number of transformer encoder layers
        trans_dropout = 0.0  # dropout inside the transformer encoder layers
        self.attention = None

        if backbone is None:
            # use a ResNet50 backbone
            net = models.resnet50(pretrained=pretrained)
            nfc = net.fc.in_features  # save the number of final features
            net.fc = torch.nn.Identity()  # remove the final linear layer

            self.extra_outputs = {}
            if mil_mode == "att_trans_pyramid":
                # register hooks to capture the outputs of intermediate layers
                def forward_hook(layer_name):
                    def hook(module, input, output):
                        self.extra_outputs[layer_name] = output

                    return hook

                # keep references to the hook handles
                self.fhooks = [
                    net.layer1.register_forward_hook(forward_hook("layer1")),
                    net.layer2.register_forward_hook(forward_hook("layer2")),
                    net.layer3.register_forward_hook(forward_hook("layer3")),
                    net.layer4.register_forward_hook(forward_hook("layer4")),
                ]

        else:
            # use a custom backbone (untested)
            net = backbone
            nfc = backbone_nfeatures

            if backbone_nfeatures is None:
                raise ValueError("Number of encoder features must be provided for a custom backbone model")
            if mil_mode not in ["mean", "max", "att", "att_trans"]:
                raise ValueError("Custom backbone is not supported for the mode: " + str(mil_mode))

        if self.mil_mode in ["mean", "max"]:
            pass

        elif self.mil_mode == "att":
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        elif self.mil_mode == "att_trans":
            transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)
            self.transformer = nn.TransformerEncoder(transformer, num_layers=n_trans)
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        elif self.mil_mode == "att_trans_pyramid":
            # pyramid of transformers over the intermediate ResNet50 features
            # (layer1: 256, layer2: 512, layer3: 1024, layer4: 2048 channels);
            # each stage concatenates the previous output with the next layer's features
            self.transformer = nn.ModuleList(
                [
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=n_trans
                    ),
                    nn.Sequential(
                        nn.Linear(768, 256),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=n_trans
                        ),
                    ),
                    nn.Sequential(
                        nn.Linear(1280, 256),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=n_trans
                        ),
                    ),
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout), num_layers=n_trans
                    ),
                ]
            )
            nfc = nfc + 256
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        else:
            raise ValueError("Unsupported mil_mode: " + str(mil_mode))

        self.myfc = nn.Linear(nfc, num_classes)
        self.net = net

    def calc_head(self, x, sh):
        # restore the bag structure: (batch, num_instances, features)
        x = x.reshape(sh[0], sh[1], -1)

        if self.mil_mode == "mean":
            x = self.myfc(x)
            x = torch.mean(x, dim=1)

        elif self.mil_mode == "max":
            x = self.myfc(x)
            x, _ = torch.max(x, dim=1)

        elif self.mil_mode == "att":
            # attention-weighted average of instance features
            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        elif self.mil_mode == "att_trans":
            # the transformer expects (num_instances, batch, features)
            x = x.permute(1, 0, 2)
            x = self.transformer(x)
            x = x.permute(1, 0, 2)

            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        elif self.mil_mode == "att_trans_pyramid":
            # spatially pooled intermediate ResNet features, reshaped to (num_instances, batch, channels)
            l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)

            x = self.transformer[0](l1)
            x = self.transformer[1](torch.cat((x, l2), dim=2))
            x = self.transformer[2](torch.cat((x, l3), dim=2))
            x = self.transformer[3](torch.cat((x, l4), dim=2))

            x = x.permute(1, 0, 2)

            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        else:
            raise ValueError("Unsupported mil_mode: " + str(self.mil_mode))

        return x

    def forward(self, x, no_head=False):
        # x: (batch, num_instances, channels, height, width)
        sh = x.shape
        x = x.reshape(-1, sh[2], sh[3], sh[4])  # merge batch and instance dims for the backbone

        x = self.net(x)

        if not no_head:
            x = self.calc_head(x, sh)

        return x
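
For context, a minimal usage sketch (not part of the commit): it assumes a 2-class problem and bags of 5 RGB instances at 224x224, and follows only the tensor shapes implied by forward and calc_head above.

# hypothetical usage sketch, assuming 2 classes and bags of 5 RGB 224x224 instances
model = MILModel(num_classes=2, mil_mode="att", pretrained=False)
model.eval()

x = torch.rand(3, 5, 3, 224, 224)  # (batch, num_instances, channels, height, width)

with torch.no_grad():
    logits = model(x)               # attention-pooled bag logits, shape (3, 2)
    feats = model(x, no_head=True)  # per-instance backbone features, shape (3 * 5, 2048)

print(logits.shape, feats.shape)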