-
Notifications
You must be signed in to change notification settings - Fork 0
/
resnet50.py
155 lines (135 loc) · 4.69 KB
/
resnet50.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# %% Imports
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import copy
import pandas as pd
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:54.1"
# Set the compute device: use the first GPU when CUDA is available and fall
# back to the CPU otherwise, so the script still runs on machines without CUDA.
print(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# %% Load data
TRAIN_ROOT = "D:/research/xai-series-master/xai-series-master/data/brain_mri/training"
TEST_ROOT = "D:/research/xai-series-master/xai-series-master/data/brain_mri/testing"
# Open both image folders without any transform, training split first.
train_dataset, test_dataset = [
    torchvision.datasets.ImageFolder(root=folder)
    for folder in (TRAIN_ROOT, TEST_ROOT)
]
# %% Building the model
class CNNModel(nn.Module):
    """Pretrained torchvision ResNet-50 with a replaced classification head.

    Args:
        num_classes: Number of logits produced by the new final fully
            connected layer. Defaults to 4, the value this script was
            originally hard-coded with.
    """

    def __init__(self, num_classes: int = 4):
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept unchanged because the installed
        # torchvision version cannot be determined from this file.
        self.resnet50 = models.resnet50(pretrained=True)
        # Replace the stock classifier with one sized for this task, keeping
        # the backbone's feature width as the input dimension.
        in_features = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the ResNet-50 and return its output logits."""
        return self.resnet50(x)
# Build the network and move it onto the selected device
# (``nn.Module.to`` returns the module itself).
model = CNNModel().to(device)
# Bare expression, presumably kept so an interactive cell displays the module.
model
# %% Prepare data for pretrained model
# Both splits use the same preprocessing: resize to 255x255, then convert the
# image to a tensor.
preprocess = transforms.Compose([
    transforms.Resize((255, 255)),
    transforms.ToTensor(),
])
train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_ROOT, transform=preprocess)
test_dataset = torchvision.datasets.ImageFolder(root=TEST_ROOT, transform=preprocess)
#train_dataset[0][0].permute(1,2,0)
# %% Create data loaders
batch_size = 32
# Train and test loaders are configured identically: same batch size, shuffled.
loader_options = {"batch_size": batch_size, "shuffle": True}
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)
test_loader = torch.utils.data.DataLoader(test_dataset, **loader_options)
# %% Train
cross_entropy_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
epochs = 5
# Set training mode explicitly so re-running this cell always trains with
# train-time layer behaviour, whatever mode the model was left in.
model.train()
# Iterate x epochs over the train data
for epoch in range(epochs):
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()
        outputs = model(inputs)
        # CrossEntropyLoss takes the integer class indices directly; the
        # labels are not one-hot encoded.
        loss = cross_entropy_loss(outputs, labels)
        loss.backward()
        optimizer.step()
    # NOTE(review): the original indentation of this print was lost; it is
    # assumed to report the last batch's loss once per epoch - confirm.
    print("This is loss-->", loss)
# %% Inspect predictions for first batch
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.numpy()
# Predict in evaluation mode and without autograd: in training mode the
# ResNet's BatchNorm layers would normalise with this batch's statistics and
# update their running averages, so merely inspecting predictions would both
# distort them and alter the model. The previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        # Predicted class = index of the largest logit in each row.
        outputs = model(inputs).max(1).indices.cpu().numpy()
finally:
    model.train(was_training)
comparison = pd.DataFrame()
print("Batch accuracy: ", (labels == outputs).sum() / len(labels))
comparison["labels"] = labels
comparison["outputs"] = outputs
# Bare expression, presumably kept so an interactive cell displays the table.
comparison
# %% Layerwise relevance propagation for ResNet50
def new_layer(layer, g):
    """Clone a layer and pass its parameters through the function g.

    Args:
        layer: Module to copy; the original is left untouched.
        g: Callable applied to the clone's ``weight`` and ``bias`` tensors;
            each result is wrapped in a new ``torch.nn.Parameter``.

    Returns:
        A deep copy of ``layer`` whose existing ``weight`` / ``bias`` have
        been replaced by ``g(weight)`` / ``g(bias)``. Attributes that are
        missing or ``None`` (e.g. a bias-free layer) are left unchanged.
    """
    clone = copy.deepcopy(layer)
    for name in ("weight", "bias"):
        param = getattr(clone, name, None)
        # Skip absent parameters explicitly: calling ``g(None)`` would either
        # raise or, for a pass-through ``g``, make ``Parameter(None)`` install
        # an empty tensor as the bias.
        if param is None:
            continue
        setattr(clone, name, torch.nn.Parameter(g(param)))
    return clone
def dense_to_conv(layers):
    """Replace every ``nn.Linear`` in ``layers`` with an ``nn.Conv2d``.

    Args:
        layers: Sequence of modules. Non-linear entries are passed through
            unchanged (the same objects, not copies).

    Returns:
        A new list of the same length. A linear layer at position 0 becomes
        a 3-input-channel 7x7 convolution (stride 2, padding 3), so its
        weight must reshape to ``(out_features, 3, 7, 7)``. A linear layer
        at any other position becomes a 1x1 convolution with
        ``in_features`` input channels. The bias is carried over only when
        the linear layer has one.
    """
    converted = []
    for position, layer in enumerate(layers):
        if not isinstance(layer, nn.Linear):
            converted.append(layer)
            continue

        out_features, in_features = layer.weight.shape
        if position == 0:
            in_channels = 3
            conv = nn.Conv2d(in_channels, out_features, 7, stride=2, padding=3, bias=False)
            conv.weight = nn.Parameter(
                layer.weight.reshape(out_features, in_channels, 7, 7)
            )
        else:
            conv = nn.Conv2d(in_features, out_features, 1, bias=False)
            conv.weight = nn.Parameter(
                layer.weight.reshape(out_features, in_features, 1, 1)
            )

        # Copy the bias only when it exists: ``nn.Parameter(None)`` would
        # create an empty tensor and break the convolution's forward pass.
        if layer.bias is not None:
            conv.bias = nn.Parameter(layer.bias)
        converted.append(conv)
    return converted
def get_linear_layers(model):
    """Return the ``nn.Linear`` modules among a model's direct children.

    The model's last child (the FC layer) is excluded from the search, and
    only direct children are inspected - nested sub-modules are not
    traversed.

    Args:
        model: Module whose ``children()`` are examined.

    Returns:
        List of the ``nn.Linear`` children, in definition order.
    """
    candidates = list(model.children())[:-1]  # remove last layer (FC layer)
    return [child for child in candidates if isinstance(child, nn.Linear)]
def apply_lrp_on_resnet50(model, image):
image = torch.unsqueeze(image, 0)
# >>> Step 1: Extract layers
layers = list(model.resnet50._modules['conv1']) \
+ list(model.resnet50._modules['layer1']) \
+ list(model.resnet50._modules['layer2']) \
+ list(model.resnet50._modules['layer3']) \
+ list(model.resnet50._modules['layer4']) \
+ [model.resnet50._modules['avgpool']] \
+ dense_to_conv