-
Notifications
You must be signed in to change notification settings - Fork 1
/
Update.py
114 lines (103 loc) · 4.52 KB
/
Update.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn import metrics
import copy
#matplotlib.use('Agg')
class DatasetSplit(Dataset):
    """View onto ``dataset`` restricted to the positions listed in ``idxs``.

    Element ``k`` of this split is ``dataset[int(idxs[k])]``; each underlying
    sample is unpacked as an ``(image, label)`` pair and returned as such.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise the indices so they can be addressed by position.
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        position = self.idxs[item]
        image, label = self.dataset[int(position)]
        return (image, label)
class LocalUpdate(object):
    """Local training and evaluation on one client's subset of the data.

    Args:
        args: configuration namespace. Attributes read by this class:
            ``model``, ``dataset``, ``local_bs``, ``lr``, ``local_ep`` and
            ``gpu`` (``-1`` means CPU); ``dataset_train`` / ``dataset_test``
            for the ``'svm'`` model or for datasets other than mnist / cifar /
            FashionMNIST.
        dataset: indexable dataset yielding ``(image, label)`` pairs, used for
            the ``'mlp'`` and ``'cnn'`` models.
        idxs: indices of this client's samples.
        tb: logger exposing ``add_scalar(tag, value)`` (presumably a
            TensorBoard ``SummaryWriter`` -- confirm against the caller).
    """

    def __init__(self, args, dataset, idxs, tb):
        self.args = args
        self.tb = tb
        self.loss_func = nn.CrossEntropyLoss()
        if args.model == 'svm':
            self.ldr_train, self.ldr_test = self.args.dataset_train[idxs, :], self.args.dataset_test[idxs]
        elif args.model in ('mlp', 'cnn'):
            self.ldr_train, self.ldr_test = self.train_val_test(dataset, list(idxs))
        # NOTE(review): for any other ``args.model`` no loaders are created, so
        # update_weights() / test() would fail with AttributeError.

    def train_val_test(self, dataset, idxs):
        """Build the ``(train, test)`` data sources for the client's indices.

        For mnist / cifar / FashionMNIST this returns two DataLoaders over
        ``DatasetSplit(dataset, idxs)``: a shuffled one with batch size
        ``args.local_bs`` and an unshuffled one holding all samples in a single
        batch. Otherwise it returns ``args.dataset_train[idxs]`` and
        ``args.dataset_test[idxs]``.
        """
        idxs_train = idxs
        if self.args.dataset in ('mnist', 'cifar', 'FashionMNIST'):
            # NOTE: the evaluation split uses the same indices as the training
            # split, so test() measures accuracy on the client's training data.
            idxs_test = idxs
            train = DataLoader(DatasetSplit(dataset, idxs_train), batch_size=self.args.local_bs, shuffle=True)
            test = DataLoader(DatasetSplit(dataset, idxs_test), batch_size=len(idxs_test), shuffle=False)
        else:
            train = self.args.dataset_train[idxs]
            test = self.args.dataset_test[idxs]
        return train, test

    def update_weights(self, net):
        """Train ``net`` in place with plain SGD for ``args.local_ep`` epochs.

        Every training batch's loss is logged via ``self.tb.add_scalar('loss', ...)``,
        and the model is evaluated with :meth:`test` after each epoch.

        Returns:
            ``(w_1st_ep, w_2st_ep, w, avg_loss, avg_acc)`` where
            ``w_1st_ep`` is a deep copy of the state dict after epoch 0;
            ``w_2st_ep`` is a deep copy after epoch 1 (a copy of ``w_1st_ep``
            when only one epoch runs); ``w`` is ``net.state_dict()`` after the
            last epoch (not copied -- its tensors share storage with ``net``);
            ``avg_loss`` is the mean over epochs of the mean per-batch training
            loss; ``avg_acc`` is the mean over epochs of the post-epoch
            accuracy reported by :meth:`test`.

        Raises:
            ValueError: if ``args.local_ep`` is less than 1, or the training
                loader yields no batches.
        """
        if self.args.local_ep < 1:
            raise ValueError('local_ep must be >= 1, got {}'.format(self.args.local_ep))
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0)
        epoch_loss = []
        epoch_acc = []
        for epoch in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.ldr_train:
                if self.args.gpu != -1:
                    images, labels = images.cuda(), labels.cuda()
                net.zero_grad()
                outputs = net(images)
                loss = self.loss_func(outputs, labels)
                loss.backward()
                optimizer.step()
                loss_value = loss.item()
                self.tb.add_scalar('loss', loss_value)
                batch_loss.append(loss_value)
            if not batch_loss:
                raise ValueError('ldr_train yielded no batches')
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
            acc, _ = self.test(net)
            epoch_acc.append(acc)
            # Snapshot the weights after the first and second local epochs.
            if epoch == 0:
                w_1st_ep = copy.deepcopy(net.state_dict())
                w_2st_ep = copy.deepcopy(w_1st_ep)
            elif epoch == 1:
                w_2st_ep = copy.deepcopy(net.state_dict())
        avg_loss = sum(epoch_loss) / len(epoch_loss)
        avg_acc = sum(epoch_acc) / len(epoch_acc)
        w = net.state_dict()
        return w_1st_ep, w_2st_ep, w, avg_loss, avg_acc

    def test(self, net):
        """Evaluate ``net`` on every batch of ``self.ldr_test``.

        ``net`` is converted to float parameters in place (``net.float()``),
        run in eval mode without building an autograd graph, and then put back
        into whichever train/eval mode it had on entry.

        Returns:
            ``(acc, loss)``: the fraction of correctly classified samples over
            all batches, and the sample-weighted mean of ``self.loss_func``
            over all batches.

        Raises:
            ValueError: if ``self.ldr_test`` yields no samples.
        """
        net = net.float()
        was_training = net.training
        net.eval()
        correct = 0
        total = 0
        loss_sum = 0.0
        try:
            with torch.no_grad():
                for images, labels in self.ldr_test:
                    if self.args.gpu != -1:
                        images, labels = images.cuda(), labels.cuda()
                    outputs = net(images)
                    batch_size = labels.size(0)
                    # loss_func averages over the batch; re-weight by batch size
                    # so the final value is the mean over all samples.
                    loss_sum += self.loss_func(outputs, labels).item() * batch_size
                    correct += (outputs.argmax(dim=1) == labels).sum().item()
                    total += batch_size
        finally:
            # Callers such as update_weights() rely on the mode being unchanged.
            net.train(was_training)
        if total == 0:
            raise ValueError('ldr_test yielded no samples')
        return correct / total, loss_sum / total