''' Incremental-Classifier Learning
Authors : Khurram Javed, Muhammad Talha Paracha
Maintainer : Khurram Javed
Lab : TUKL-SEECS R&D Lab
Email : 14besekjaved@seecs.edu.pk '''
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from resnet import *
class BasicNet1(nn.Module):
    """Backbone network plus an incrementally growable linear classifier.

    The backbone (``self.convnet``) is selected from ``args.dataset``; the
    classifier head starts empty and is grown task-by-task through
    :meth:`add_classes`, either as a single expanding ``nn.Linear`` or as one
    independent head per task (``use_multi_fc``).
    """

    def __init__(
        self, args, use_bias=False, init="kaiming", use_multi_fc=False, device=None
    ):
        """
        Args:
            args: namespace providing at least ``dataset`` and ``num_class``.
            use_bias: whether classifier ``nn.Linear`` layers carry a bias.
            init: weight-init scheme for new heads; only ``"kaiming"`` is
                recognized, any other value keeps PyTorch's default init.
            use_multi_fc: grow one independent head per call to
                :meth:`add_classes` instead of widening a single head.
            device: target device for the model and new heads; when ``None``
                the model is placed on the default CUDA device (the
                historical behavior).

        Raises:
            ValueError: if ``args.dataset`` is not one of the known datasets.
        """
        super(BasicNet1, self).__init__()
        self.use_bias = use_bias
        self.init = init
        self.use_multi_fc = use_multi_fc
        self.args = args

        dataset = self.args.dataset
        if dataset == "mnist":
            self.convnet = RPS_net_mlp()
        elif dataset in ("svhn", "cifar100", "omniglot"):
            # These three datasets share the same RPS backbone.
            self.convnet = RPS_net(self.args.num_class)
        elif dataset == "celeb":
            self.convnet = resnet18()
        else:
            # Fail fast: the original fell through silently and crashed
            # later with a missing ``convnet`` attribute.
            raise ValueError("unknown dataset: {!r}".format(dataset))

        self.classifier = None
        self.n_classes = 0
        self.device = device
        # Honor an explicitly requested device; fall back to the original
        # hard-coded CUDA placement when none is given.
        if device is not None:
            self.to(device)
        else:
            self.cuda()

    def forward(self, x):
        """Run the backbone on ``x`` and return its two outputs unchanged."""
        x1, x2 = self.convnet(x)
        return x1, x2

    @property
    def features_dim(self):
        """Dimensionality of the backbone's feature output."""
        return self.convnet.out_dim

    def extract(self, x):
        """Return the raw backbone output for ``x`` (no classifier applied)."""
        return self.convnet(x)

    def freeze(self):
        """Freeze every parameter, switch to eval mode, and return ``self``."""
        for param in self.parameters():
            param.requires_grad = False
        self.eval()
        return self

    def copy(self):
        """Return a deep copy of the whole network."""
        return copy.deepcopy(self)

    def add_classes(self, n_classes):
        """Grow the classifier by ``n_classes`` output units."""
        if self.use_multi_fc:
            self._add_classes_multi_fc(n_classes)
        else:
            self._add_classes_single_fc(n_classes)
        self.n_classes += n_classes

    def _add_classes_multi_fc(self, n_classes):
        # Register a new independent head. ``self.classifier`` holds the
        # attribute NAMES of the heads, not the modules themselves.
        if self.classifier is None:
            self.classifier = []
        new_classifier = self._gen_classifier(n_classes)
        name = "_clf_{}".format(len(self.classifier))
        self.__setattr__(name, new_classifier)
        self.classifier.append(name)

    def _add_classes_single_fc(self, n_classes):
        # Replace the single head with a wider one, copying the old weights
        # (and bias, if used) into the first ``self.n_classes`` rows.
        weight, bias = None, None
        if self.classifier is not None:
            weight = copy.deepcopy(self.classifier.weight.data)
            if self.use_bias:
                bias = copy.deepcopy(self.classifier.bias.data)

        classifier = self._gen_classifier(self.n_classes + n_classes)

        if weight is not None:
            classifier.weight.data[:self.n_classes] = weight
            if bias is not None:
                classifier.bias.data[:self.n_classes] = bias

        del self.classifier
        self.classifier = classifier

    def _gen_classifier(self, n_classes):
        """Create a fresh ``nn.Linear`` head with ``n_classes`` outputs."""
        # Place the head on the configured device; default to CUDA to
        # preserve the original hard-coded ``.cuda()`` behavior.
        target = self.device if self.device is not None else torch.device("cuda")
        classifier = nn.Linear(
            self.convnet.out_dim, n_classes, bias=self.use_bias
        ).to(target)
        if self.init == "kaiming":
            nn.init.kaiming_normal_(classifier.weight, nonlinearity="linear")
            if self.use_bias:
                nn.init.constant_(classifier.bias, 0.)
        return classifier