client.py (forked from cap-ntu/FedReID)
import time
import copy

import torch
import torch.nn as nn

from utils import get_optimizer, get_model
from optimization import Optimization
class Client():
    def __init__(self, cid, data, device, project_dir, model_name, local_epoch,
                 lr, batch_size, drop_rate, stride, clustering=False):
        self.cid = cid
        self.project_dir = project_dir
        self.model_name = model_name
        self.data = data
        self.device = device
        self.local_epoch = local_epoch
        self.lr = lr
        self.batch_size = batch_size

        self.dataset_sizes = self.data.train_dataset_sizes[cid]
        self.train_loader = self.data.train_loaders[cid]

        # Build the full model for this client's local identity classes, then
        # detach the classifier head: only the backbone is exchanged with the
        # server, while the head stays private to the client.
        self.model = get_model(self.data.train_class_sizes[cid], drop_rate, stride)
        self.classifier = copy.deepcopy(self.model.classifier.classifier)
        self.model.classifier.classifier = nn.Sequential()

        self.distance = 0
        self.optimization = Optimization(self.train_loader, self.device)
        self.use_clustering = clustering
    def train(self, federated_model=None, use_cuda=False):
        # use_cuda is retained for API compatibility; device placement is
        # handled via self.device below.
        self.y_err = []
        self.y_loss = []

        if self.use_clustering:
            # The cluster-specific model was already assigned via set_model(),
            # so no federated model is passed in.
            print("using clustering, model is set before")
            assert federated_model is None
            federated_model = copy.deepcopy(self.model)
        else:
            # Start from the current federated (global) backbone.
            self.model.load_state_dict(federated_model.state_dict())

        # Reattach this client's private classifier head for local training.
        self.model.classifier.classifier = self.classifier
        self.old_classifier = copy.deepcopy(self.classifier)
        self.model = self.model.to(self.device)
        self.model.train(True)

        optimizer = get_optimizer(self.model, self.lr)
        criterion = nn.CrossEntropyLoss()

        since = time.time()
        print('Client', self.cid, 'start training')
        for epoch in range(self.local_epoch):
            print('Epoch {}/{}'.format(epoch, self.local_epoch - 1))
            print('-' * 10)

            running_loss = 0.0
            running_corrects = 0.0

            for inputs, labels in self.train_loader:
                b, c, h, w = inputs.shape
                # Skip incomplete batches so batch statistics stay consistent.
                if b < self.batch_size:
                    continue
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                outputs = self.model(inputs)
                _, preds = torch.max(outputs.detach(), 1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * b
                running_corrects += float(torch.sum(preds == labels))

            # Only full batches were used, so normalize by the truncated size.
            used_data_sizes = self.dataset_sizes - self.dataset_sizes % self.batch_size
            epoch_loss = running_loss / used_data_sizes
            epoch_acc = running_corrects / used_data_sizes

            print('{} Loss: {:.4f} Acc: {:.4f}'.format('train', epoch_loss, epoch_acc))

            self.y_loss.append(epoch_loss)
            self.y_err.append(1.0 - epoch_acc)

            time_elapsed = time.time() - since
            print('Client', self.cid, 'epoch complete in {:.0f}m {:.0f}s'.format(
                time_elapsed // 60, time_elapsed % 60))

        time_elapsed = time.time() - since
        print('Client', self.cid, 'training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        # Detach the private head again before the model is shared with the
        # server: only the backbone participates in aggregation.
        self.classifier = self.model.classifier.classifier
        self.distance = self.optimization.cdw_feature_distance(
            federated_model, self.old_classifier, self.model)
        self.model.classifier.classifier = nn.Sequential()
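
    # Editorial note (assumption, not from this repository's docs):
    # cdw_feature_distance is expected to return a scalar drift measure for
    # this client, e.g. the mean cosine distance
    #     d = mean_x(1 - cos(f_federated(x), f_local(x)))
    # between features of the federated and locally trained models on local
    # batches x; the server reads it back via get_cos_distance_weight() to
    # weight this client's update during aggregation.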
    def generate_soft_label(self, x, regularization):
        # Soft labels for knowledge distillation on shared data x.
        return self.optimization.kd_generate_soft_label(self.model, x, regularization)

    def generate_custom_data_feature(self, inputs):
        return self.optimization.generate_custom_data_feature(self.model, inputs)

    def get_model(self):
        return self.model

    def get_data_sizes(self):
        return self.dataset_sizes

    def get_train_loss(self):
        # Loss of the most recent local epoch.
        return self.y_loss[-1]

    def get_cos_distance_weight(self):
        return self.distance

    def set_model(self, model):
        self.model = copy.deepcopy(model)
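
# ---------------------------------------------------------------------------
# Usage sketch (editorial addition, illustrative only). The orchestration
# below is an assumption: the real entry point in this repository is the
# server-side training script, and `Data`, `federated_avg`, and the argument
# values here are hypothetical stand-ins for its actual APIs.
# ---------------------------------------------------------------------------
#
# data = Data(...)  # hypothetical wrapper exposing train_loaders, client_list, ...
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# clients = [
#     Client(cid, data, device, project_dir='.', model_name='ft_ResNet50',
#            local_epoch=1, lr=0.05, batch_size=32, drop_rate=0.5, stride=2)
#     for cid in data.client_list
# ]
#
# federated_model = clients[0].get_model()  # shared backbone (head detached)
# for rnd in range(100):
#     for client in clients:
#         client.train(federated_model)     # local SGD on private data
#     # Weight each client's update by its cosine-distance weight (CDW).
#     weights = [c.get_cos_distance_weight() for c in clients]
#     federated_model = federated_avg([c.get_model() for c in clients], weights)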