-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
108 lines (92 loc) · 4.07 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import functools

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from tqdm import tqdm

from hyperparameter import *
from encrypt import encrypt_image
# True when a CUDA device is visible to torch in this process.
use_cuda = torch.cuda.is_available()


def function(x):
    """Apply ``encrypt_image`` to one sample; used as a ``transforms.Lambda`` step.

    Defined with ``def`` rather than as a lambda so it can be pickled by
    reference: DataLoader worker processes (``num_workers > 0``) started with
    the ``spawn`` method (the default on Windows and macOS) must pickle the
    dataset's transforms, and a module-level lambda cannot be pickled.
    """
    return encrypt_image(x)
def val(device, net, valloader):
    """Compute, print and return the classification accuracy of ``net`` on ``valloader``.

    The network is switched to evaluation mode for the pass, so layers such as
    Dropout/BatchNorm behave in inference mode. Its previous train/eval mode is
    restored afterwards, even if an exception is raised.

    Args:
        device: device every batch is moved to (should match ``net``'s device).
        net: classifier returning one score per class along dim 1.
        valloader: iterable of ``(inputs, labels)`` batches.

    Returns:
        Accuracy in percent as a float, or ``nan`` if the loader yielded no samples.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                # Predicted class = index of the highest score per row.
                _, predict = outputs.max(1)
                total += labels.size(0)
                correct += (predict == labels).sum().item()
    finally:
        net.train(was_training)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else float('nan')
    print('Accuracy of the network on the val images: %.3f %%' % accuracy)
    return accuracy
def train(model, device, criterion, trainloader, optimizer, epochs, epoch_count=None):
    """Train ``model`` on ``trainloader`` for ``epochs`` passes, showing a tqdm progress bar.

    Args:
        model: network to optimise; put into training mode on entry.
        device: device each batch is moved to (should match the model's device).
        criterion: loss function called as ``criterion(outputs, labels)``.
        trainloader: iterable of ``(inputs, labels)`` batches; ``len()`` is used
            as the progress-bar total.
        optimizer: optimizer over the model's parameters.
        epochs: number of passes over ``trainloader``.
        epoch_count: if given, this number is displayed as the epoch in the
            progress bar instead of the 1-based local epoch index (useful when
            the caller drives training one epoch at a time).

    Returns:
        The mean training loss of the last epoch that processed a batch, as a
        float, or ``None`` if no batch was processed.
    """
    model.train()
    train_loss = None
    for epoch in range(1, epochs + 1):
        epoch_print = epoch if epoch_count is None else epoch_count
        pbar = tqdm(trainloader, total=len(trainloader))
        loss_sum = 0.0
        for batch_count, (inputs, labels) in enumerate(pbar, start=1):
            # Always move: .to() is a no-op when the tensor is already on
            # ``device``, and this also works for non-CUDA accelerators.
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() gives a detached Python float, so the running sum holds
            # no tensors or autograd state.
            loss_sum += loss.item()
            train_loss = loss_sum / batch_count
            pbar.set_description("Epoch: %d - loss: %.5f" % (epoch_print, train_loss))
    return train_loss
# Per-channel normalisation constants used for every dataset built below.
# NOTE(review): these are CIFAR-10 statistics and are also applied when
# Dataset selects CIFAR-100; kept as-is so existing models stay compatible.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)


def _build_transform(encrypt):
    """Return ToTensor + Normalize, followed by the ``function`` encryption step if ``encrypt``."""
    steps = [transforms.ToTensor(), transforms.Normalize(_NORM_MEAN, _NORM_STD)]
    if encrypt:
        steps.append(transforms.Lambda(function))
    return transforms.Compose(steps)


def _build_loaders(encrypt):
    """Download (if needed) and wrap the train/test split selected by ``Dataset`` in DataLoaders."""
    print('==> Preparing data..')
    # "CIFAR10" selects CIFAR-10; any other value falls back to CIFAR-100.
    if Dataset == "CIFAR10":
        dataset_cls = torchvision.datasets.CIFAR10
    else:
        dataset_cls = torchvision.datasets.CIFAR100
    trainset = dataset_cls(root='./data', train=True, download=True,
                           transform=_build_transform(encrypt))
    valset = dataset_cls(root='./data', train=False, download=True,
                         transform=_build_transform(encrypt))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=Batch_Size, shuffle=True, num_workers=2)
    valloader = torch.utils.data.DataLoader(valset, batch_size=Batch_Size, shuffle=False, num_workers=2)
    return trainloader, valloader


def prepare_data():
    """Return ``(trainloader, valloader)`` over the plain (unencrypted) dataset chosen by ``Dataset``.

    Images are converted to tensors and normalised only. The train loader
    shuffles; the val loader does not. Both use ``Batch_Size`` and 2 workers.
    """
    return _build_loaders(encrypt=False)


def prepare_morphed_data():
    """Return ``(trainloader, valloader)`` whose samples are normalised and then passed through ``function``.

    Same dataset selection and loader settings as ``prepare_data``.
    """
    return _build_loaders(encrypt=True)
@functools.lru_cache(maxsize=None)
def _inverse_matrix(matrix_path):
    """Load the matrix saved at ``matrix_path`` and return its inverse as a float32 CPU tensor.

    Cached so repeated ``datareverse`` calls neither re-read the file nor
    re-invert the matrix. The cache is keyed on the path string, so a file
    rewritten later in the same process (or a relative path used after
    changing working directory) is not reloaded.
    """
    trans_mat = np.load(matrix_path)
    return torch.from_numpy(np.linalg.inv(trans_mat)).float()


def datareverse(data, matrix_path="mul_matrix.npy"):
    """Right-multiply the first three channels of ``data`` by the inverse of the saved matrix.

    Presumably this undoes the transform applied by ``encrypt_image`` — TODO
    confirm against encrypt.py.

    Args:
        data: indexable whose entries 0..2 are 2-D tensors whose column count
            matches the matrix size (e.g. a ``(3, H, W)`` tensor). It is
            modified in place.
        matrix_path: ``.npy`` file holding the square transformation matrix.

    Returns:
        ``data`` itself, after the in-place update.
    """
    reverse_m = _inverse_matrix(matrix_path)
    for i in range(3):
        channel = data[i]
        # Match the channel's device/dtype so CUDA or float64 inputs work too.
        data[i] = channel.mm(reverse_m.to(device=channel.device, dtype=channel.dtype))
    return data