-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
32 lines (29 loc) · 1 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
def define_model(name='resnet32', num_classes=2, pretrained=True):
    """Build a classification network whose final layer has `num_classes` outputs.

    Args:
        name: Architecture key. Supported values:
            ``'resnet18'``  -> ``models.resnet18``, ``fc`` replaced;
            ``'resnet32'``  -> ``models.resnet34`` (the key says 32 but the
                               network is ResNet-34; kept for backward
                               compatibility), ``fc`` replaced;
            ``'resnet34'``  -> alias of ``'resnet32'``;
            ``'densenet'``  -> ``models.densenet161``, ``classifier`` replaced;
            ``'resneXt'``   -> project-local ``RESNEXT.model.ResNeXt101().net``,
                               returned unchanged.
        num_classes: Output size of the replacement ``nn.Linear`` head for the
            torchvision variants. Defaults to 2, matching the previous
            hard-coded value.
        pretrained: Forwarded as ``pretrained=`` to the torchvision
            constructors. Defaults to True, matching the previous behaviour.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``name`` is not one of the supported keys.
    """
    if name == 'resnet18':
        model = models.resnet18(pretrained=pretrained)
        # Swap the final fully connected layer for one sized to num_classes,
        # keeping the backbone's feature width.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif name in ('resnet32', 'resnet34'):
        model = models.resnet34(pretrained=pretrained)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif name == 'densenet':
        model = models.densenet161(pretrained=pretrained)
        # DenseNet exposes its head as `classifier` rather than `fc`.
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif name == 'resneXt':
        # Imported lazily so the project-local RESNEXT package is only
        # required when this variant is actually requested.
        from RESNEXT.model import ResNeXt101
        # NOTE(review): num_classes / pretrained are not applied to this
        # branch; the network's own head is used as-is. Confirm its output
        # size matches what callers expect.
        model = ResNeXt101().net
    else:
        # Previously an unknown name fell through to `return model` and died
        # with an UnboundLocalError; fail with an explicit message instead.
        raise ValueError(
            'Unknown model name {!r}; expected one of '
            "'resnet18', 'resnet32', 'resnet34', 'densenet', 'resneXt'".format(name)
        )
    return model
if __name__ == '__main__':
    # Manual smoke check: construct the ResNeXt variant and dump its module
    # tree to stdout. Change the name to try another architecture.
    print(define_model(name='resneXt'))