host/load pretrained weights for 3D resnet #48
see also Project-MONAI/MONAI#271 |
Now that we have the option of the project shared drive for this (Project-MONAI/MONAI#2489), do you want to revisit the PR? @Douwe-Spaanderman |
Sorry, but I didn't find a solution for using the pretrained ResNet. @wyli |
I have downloaded the weights and am loading them with a state dict:
import torch
from monai.networks.nets import resnet10, resnet18, resnet34, resnet50
PATH_PRETRAINED_WEIGHTS = "/home/jovyan/work/pretrained/resnet_10_23dataset.pth"
net = resnet10(
    pretrained=False,
    spatial_dims=3,
)
net.load_state_dict(torch.load(PATH_PRETRAINED_WEIGHTS))
but I still get the following error:
|
@Borda I found a solution in https://github.com/Tencent/MedicalNet. This code works:
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
    'resnet152', 'resnet200'
]


def generate_model(opt):
    # Map each supported depth to its constructor (the constructors are defined below).
    constructors = {
        10: resnet10,
        18: resnet18,
        34: resnet34,
        50: resnet50,
        101: resnet101,
        152: resnet152,
        200: resnet200,
    }
    assert opt.model_depth in constructors, 'unsupported depth {}'.format(opt.model_depth)
    model = constructors[opt.model_depth](shortcut_type=opt.resnet_shortcut)
    net_dict = model.state_dict()
    # load pretrain
    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
        # The Med3D checkpoints are assumed to come from an nn.DataParallel model, so their keys
        # start with 'module.'; strip it, otherwise no key matches and all weights stay random.
        prefix = 'module.'
        pretrain_state = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in pretrain['state_dict'].items()}
        # Keep only tensors whose name and shape both match this model (e.g. skip a different head).
        pretrain_dict = {
            k: v for k, v in pretrain_state.items()
            if k in net_dict and v.shape == net_dict[k].shape}
        print('loaded {}/{} model tensors from the checkpoint'.format(len(pretrain_dict), len(net_dict)))
        net_dict.update(pretrain_dict)
        model.load_state_dict(net_dict)
    return model


def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        dilation=dilation,
        stride=stride,
        padding=dilation,
        bias=False)


def downsample_basic_block(x, planes, stride, no_cuda=False):
    # Shortcut type 'A': strided 1x1x1 average pool, then zero-pad the missing channels.
    # no_cuda is kept for signature compatibility; the padding follows x's device and dtype,
    # and the result stays in the autograd graph.
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    zero_pads = torch.zeros(
        out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4),
        dtype=out.dtype, device=out.device)
    return torch.cat([out, zero_pads], dim=1)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 # sample_input_D,
                 # sample_input_H,
                 # sample_input_W,
                 num_seg_classes=2,
                 shortcut_type='B',
                 no_cuda=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.no_cuda = no_cuda
        self.conv1 = nn.Conv3d(
            1,
            64,
            kernel_size=7,
            stride=(2, 2, 2),
            padding=(3, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=1, dilation=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=1, dilation=4)
        self.conv_seg = nn.Sequential(
            nn.ConvTranspose3d(
                512 * block.expansion,
                32,
                2,
                stride=2
            ),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                32,
                32,
                kernel_size=3,
                stride=(1, 1, 1),
                padding=(1, 1, 1),
                bias=False),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                32,
                num_seg_classes,
                kernel_size=1,
                stride=(1, 1, 1),
                bias=False)
        )
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # In-place initialisers (kaiming_normal without the underscore is deprecated).
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                    no_cuda=self.no_cuda)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))
        layers = []
        layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.conv_seg(x)
        return x
def resnet10(**kwargs):
    """Constructs a ResNet-10 model.
    """
    model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
    return model


def resnet18(**kwargs):
    """Constructs a ResNet-18 model.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def resnet34(**kwargs):
    """Constructs a ResNet-34 model.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def resnet50(**kwargs):
    """Constructs a ResNet-50 model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def resnet101(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def resnet152(**kwargs):
    """Constructs a ResNet-152 model.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model


def resnet200(**kwargs):
    """Constructs a ResNet-200 model.
    """
    model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
    return model
|
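As a quick sanity check on the constructors above, the depth in each name can be recovered from its block configuration: two convolutions per BasicBlock, three per Bottleneck, plus the stem convolution and a final classifier layer. The "+2" is the usual ResNet naming convention; this segmentation variant actually ends in the conv_seg head rather than an FC layer, so treat it as a naming rule, not a literal layer count. A small pure-Python sketch; CONFIGS and nominal_depth are illustrative names, with the stage lists copied from the code above:

```python
from typing import Sequence

# Block type and blocks per stage, copied from the resnet* constructors above.
CONFIGS = {
    10: ("basic", [1, 1, 1, 1]),
    18: ("basic", [2, 2, 2, 2]),
    34: ("basic", [3, 4, 6, 3]),
    50: ("bottleneck", [3, 4, 6, 3]),
    101: ("bottleneck", [3, 4, 23, 3]),
    152: ("bottleneck", [3, 8, 36, 3]),
    200: ("bottleneck", [3, 24, 36, 3]),
}

# Weighted convolutions inside one residual block (shortcut convs are not counted).
CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def nominal_depth(block: str, layers: Sequence[int]) -> int:
    """Conventional depth: block convs + stem conv + final classifier layer."""
    return CONVS_PER_BLOCK[block] * sum(layers) + 2


if __name__ == "__main__":
    for depth, (block, layers) in CONFIGS.items():
        print(f"resnet{depth}: {block} {layers} -> {nominal_depth(block, layers)}")
```

Every entry reproduces its name, e.g. ResNet-50 is 3 * (3 + 4 + 6 + 3) + 2 = 50.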
nice, just wondering how it differs from the MONAI implementation... :) |
@JianJuly I copy-pasted your suggested code and I'm getting almost the same error when loading as follows:
import torch
PATH_PRETRAINED_WEIGHTS = "/home/jirka/Downloads/pretrain/resnet_50_23dataset.pth"
net = resnet50()
state_dict = torch.load(PATH_PRETRAINED_WEIGHTS)
net.load_state_dict(state_dict)
print(net) |
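Passing the loaded file straight to load_state_dict is likely to fail because, judging by the generate_model code posted above (which reads pretrain['state_dict']), the .pth file is a wrapper dict rather than a bare state dict. The keys are also assumed to carry a "module." prefix, which is what nn.DataParallel adds when a wrapped model is saved; check your own file before relying on this. A minimal sketch of the cleanup on plain dicts, so it can be tried without the weights file; clean_checkpoint is a hypothetical helper name:

```python
from typing import Any, Dict


def clean_checkpoint(checkpoint: Dict[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a bare state dict with a DataParallel-style key prefix removed."""
    # Assumption: Med3D files nest the weights under "state_dict"; a bare dict passes through.
    state_dict = checkpoint.get("state_dict", checkpoint)
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Usage with the real file (requires torch and the downloaded weights):
#   checkpoint = torch.load(PATH_PRETRAINED_WEIGHTS, map_location="cpu")
#   result = net.load_state_dict(clean_checkpoint(checkpoint), strict=False)
#   print(result.missing_keys, result.unexpected_keys)
# strict=False tolerates missing/unexpected keys, but tensors with mismatched shapes still raise.
```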
I could load most parameters from the state dictionary excluding
Running this code throws the following error. Only the biases of the 4 layers and the weights of the FC layer are missing.
We don't need the FC layer weights if we are fine-tuning on a different task.
This successfully loads the state dict into the MONAI ResNet implementation. |
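To see up front which keys will not line up, the split that load_state_dict(strict=False) reports as missing_keys and unexpected_keys can be reproduced on plain key lists. A minimal sketch; diff_keys and the example key names are hypothetical, and tolerating the "fc." prefix reflects the point that the FC weights are not needed when fine-tuning:

```python
from typing import Iterable, List, Tuple


def diff_keys(model_keys: Iterable[str], ckpt_keys: Iterable[str],
              tolerated_missing: Tuple[str, ...] = ()) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, mirroring load_state_dict(strict=False).

    Missing keys starting with one of `tolerated_missing` (e.g. a new classifier
    head that will be trained from scratch) are left out of the report.
    """
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    # Expected by the model but absent from the checkpoint: these keep their init values.
    missing = sorted(k for k in model_set - ckpt_set if not k.startswith(tolerated_missing))
    # Present in the checkpoint but unknown to the model: these would be ignored.
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


model_keys = ["conv1.weight", "layer1.0.downsample.0.weight",
              "layer1.0.downsample.0.bias", "fc.weight", "fc.bias"]
ckpt_keys = ["conv1.weight", "layer1.0.downsample.0.weight", "conv_seg.0.weight"]
missing, unexpected = diff_keys(model_keys, ckpt_keys, tolerated_missing=("fc.",))
print(missing)     # ['layer1.0.downsample.0.bias']
print(unexpected)  # ['conv_seg.0.weight']
```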
Yes, that's what I did in the end too; the renaming is critical... |
Update: couldn't we open a PR with this solution? It's better than getting an error! Thanks again |
In case anyone else finds this years later like me: you have to use the appropriate parameters when initializing the ResNet, as described in the NotImplementedError message. Specifically, I used
which loads the model "resnet_50_23dataset.pth" from Med3D without issues for me, after doing the same preprocessing of the dictionary keys as others did here. |
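Which shortcut parameters a checkpoint contains depends on the block type and shortcut_type. The sketch below replays the condition from _make_layer in the MedicalNet code posted earlier (stride != 1 or inplanes != planes * expansion) to list the stages that get a projection shortcut. With shortcut_type 'B' that shortcut is a 1x1 Conv3d (bias=False) followed by BatchNorm3d; with 'A' it is the parameter-free downsample_basic_block, so no downsample keys exist at all. For a Bottleneck network all four stages get one, which would be consistent with four missing biases if the target model's shortcut conv has a bias term; that connection is an assumption, not something confirmed in the thread. STAGES, EXPANSION and the helper names are illustrative:

```python
from typing import List

# (planes, stride) for layer1..layer4, as passed to _make_layer in the MedicalNet ResNet.
STAGES = [(64, 1), (128, 2), (256, 1), (512, 1)]
EXPANSION = {"basic": 1, "bottleneck": 4}  # BasicBlock.expansion / Bottleneck.expansion


def stages_with_downsample(block: str) -> List[str]:
    """Names of the stages whose first block gets a downsample shortcut."""
    inplanes = 64  # the MedicalNet ResNet starts with self.inplanes = 64
    expansion = EXPANSION[block]
    names = []
    for i, (planes, stride) in enumerate(STAGES, start=1):
        if stride != 1 or inplanes != planes * expansion:
            names.append(f"layer{i}")
        inplanes = planes * expansion
    return names


def downsample_conv_keys(block: str) -> List[str]:
    """State-dict keys of the 1x1 shortcut convs when shortcut_type == 'B'."""
    return [f"{name}.0.downsample.0.weight" for name in stages_with_downsample(block)]


print(stages_with_downsample("basic"))       # ['layer2', 'layer3', 'layer4']
print(stages_with_downsample("bottleneck"))  # ['layer1', 'layer2', 'layer3', 'layer4']
```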
model = resnet101( ) |
Is your feature request related to a problem? Please describe.
PR Project-MONAI/MONAI#2253 implements a generic version of ResNet for spatial 1D/2D/3D inputs. It would be very useful for MONAI to also provide a pretrained=True option for the model initialisations. However, there is currently a practical issue (see Project-MONAI/MONAI#2253 (comment)). cc @Douwe-Spaanderman