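"""Train ResNet-18 (optionally the affine-layer variant) on CIFAR-10 with PSGD
preconditioned optimizers (SGD baseline, PSGD_XMat, PSGD_Affine, PSGD_UVd),
with an optional second experiment stage that switches datasets mid-training."""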
import sys
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import numpy as np
import preconditioned_stochastic_gradient_descent as psgd
from tqdm import tqdm
import random
import os
from data_loaders.loaders import *
from models.resnet import ResNet18
from models.resnet_affine import ResNet18Affine
from reproduce.seeds import *
parser = argparse.ArgumentParser()
parser.add_argument("--experiment", default='cifar10', help="pick which experiment")
parser.add_argument("--stage2", default='cifar10', help="pick stage2 of experiment")
parser.add_argument("--device", default='cuda:0', help="for example, cuda:0")
parser.add_argument("--optimizer", default='PSGD_Affine', help="choices are SGD, PSGD_XMat and PSGD_UVd")
parser.add_argument("--lr_scheduler", default='cos', help="choices are stage and cos")
parser.add_argument("--num_epoch", default=200, type=int, help="how long should our full experiment be")
parser.add_argument("--epoch_concept_switch", default=201, type=int, help="when should we switch to stage2 of experiment")
parser.add_argument("--num_runs", default=1, type=int, help="how many runs")
parser.add_argument("--shortcut_connection", default=1, type=int, help="choices are 0 and 1")
parser.add_argument('--seed', default=2048, type=int, help='random seed')
parser.add_argument('--data_seed', default=1738, type=int, help='random data_seed')
parser.add_argument('--data_root', default='./data/ntga_cnn_best/', help='root of data')
args = parser.parse_args()
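# Example invocation (a sketch; assumes this repo's data_loaders/, models/ and reproduce/
# packages and the psgd module are importable from the working directory):
#   python psgd_cifar10.py --experiment cifar10 --optimizer PSGD_XMat --lr_scheduler cos --num_epoch 200 --device cuda:0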
experiment = args.experiment
stage2 = args.stage2
num_epoch = args.num_epoch
runs = args.num_runs
epoch_concept_switch = args.epoch_concept_switch
device = torch.device(args.device)
optimizer = args.optimizer
lr_scheduler = args.lr_scheduler
shortcut_connection = bool(args.shortcut_connection)
seed = args.seed
data_seed = args.data_seed
data_root = args.data_root
print("Experiment Stage 1: \t\t\t{}".format(experiment))
print("Experiment Stage 2: \t\t\t{}".format(stage2))
print("Total Epochs: \t\t\t{}".format(num_epoch))
print("Change Experiment at Epoch: \t\t\t{}".format(epoch_concept_switch))
print("Device: \t\t\t{}".format(device))
print("Optimizer: \t\t\t{}".format(optimizer))
print("Learning rate schedular:\t{}".format(lr_scheduler))
print("With short connections: \t{}".format(shortcut_connection))
print("Seed: \t\t\t\t{}".format(seed))
print("Data Seed: \t\t\t{}".format(data_seed))
print("Data Root: \t\t\t{}".format(data_root))
set_seed(args.seed)
data_seed_l = random.sample(range(1, 50000), runs)
seed_l = random.sample(range(1, 50000), runs)
print("Seed List: \t\t\t\t{}".format(seed_l))
print("Data Seed List: \t\t\t{}".format(data_seed_l))
set_cuda(deterministic=True)
torch_version = torch.__version__
print("Torch Version: \t\t\t\t{}".format(torch_version))
# if torch_version.startswith('2'):
#     exact_hessian_vector_product = False
#     print("Using inexact Hessian-vector product: torch.compile does not support double autograd")
# else:
#     exact_hessian_vector_product = True
exact_hessian_vector_product = True
if optimizer == 'SGD':
    lr0 = 1.0  # 0.1 -> 1.0 when momentum factor = 0.9, as momentum in PSGD is the moving average of the gradient
    decay = 5e-4
elif 'Affine' in optimizer:
    lr0 = 2e-1
    if shortcut_connection:
        decay = 1
    else:
        decay = 1e-2
else:  # PSGD_XMat or PSGD_UVd
    lr0 = 2e-2
    if shortcut_connection:
        decay = 2e-2
    else:
        decay = 1e-2

if shortcut_connection:
    batchsize = 128
else:
    batchsize = 64
def test(net, device, data_loader, criterion):
    # if torch.__version__.startswith('2'):
    #     net = torch.compile(net)
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    accuracy = 100.0 * correct / total
    return accuracy
def train(net, device, data_loader, criterion):
    # if torch.__version__.startswith('2'):
    #     net = torch.compile(net)
    net.train()  # do not forget this as there is BN
    total = 0
    train_loss = 0
    correct = 0
    for batch_idx, (inputs, targets) in tqdm(enumerate(data_loader), total=len(data_loader)):
        inputs, targets = inputs.to(device), targets.to(device)

        def closure():
            """
            Weight decay is realized explicitly by adding L2 regularization to the loss;
            torch.rand_like randomizes the per-weight decay strength.
            """
            outputs = net(inputs)
            loss = criterion(outputs, targets) + sum(
                torch.sum(decay * torch.rand_like(param) * param * param)
                for param in net.parameters()
            )
            return [loss, outputs]

        loss, outputs = opt.step(closure)
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    train_accuracy = 100.0 * correct / total
    return train_loss, train_accuracy
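# Note on the closure contract (a rough sketch of how psgd uses it, not its documentation):
# unlike torch.optim optimizers, psgd's opt.step(closure) evaluates the closure itself and
# differentiates through it (with create_graph when exact Hessian-vector products are
# enabled), so the closure must build a fresh forward graph and return [loss, outputs].
# Since torch.rand_like(param) is uniform on [0, 1) with mean 0.5, the randomized penalty
# above equals a deterministic L2 penalty of strength decay / 2 in expectation.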
train_accs_l = []
test_accs_l = []
for run in range(runs):
    set_seed(seed_l[run])
    if optimizer == 'PSGD_Affine':
        net = ResNet18Affine(shortcut_connection=shortcut_connection).to(device)
    else:
        net = ResNet18(shortcut_connection=shortcut_connection).to(device)
    if optimizer == 'SGD':
        # SGD baseline
        opt = psgd.XMat(
            net.parameters(),
            lr_params=lr0,  # note that momentum in PSGD is the moving average of the gradient,
            momentum=0.9,  # so lr 0.1 becomes 1 when the momentum factor is 0.9
            preconditioner_update_probability=0.0,  # PSGD reduces to SGD when P = eye()
        )
    elif optimizer == 'PSGD_XMat':
        # PSGD with X-shape matrix preconditioner
        opt = psgd.XMat(
            net.parameters(),
            lr_params=lr0,
            momentum=0.9,
            preconditioner_update_probability=0.1,
            exact_hessian_vector_product=exact_hessian_vector_product,
        )
    elif optimizer == 'PSGD_Affine':
        # PSGD with affine group preconditioner
        opt = psgd.Affine(
            net.parameters(),
            lr_params=lr0,
            momentum=0.9,
            preconditioner_update_probability=0.1,
            exact_hessian_vector_product=exact_hessian_vector_product,
        )
    else:
        # PSGD with low rank approximation preconditioner
        opt = psgd.UVd(
            net.parameters(),
            lr_params=lr0,
            momentum=0.9,
            preconditioner_update_probability=0.1,
            exact_hessian_vector_product=exact_hessian_vector_product,
        )
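    # Rough intuition for the preconditioner choices above (a sketch of the psgd package's
    # naming, not its documentation): XMat fits a single X-shaped (diagonal plus
    # anti-diagonal) matrix, UVd a diagonal-plus-low-rank matrix, and Affine
    # Kronecker-factored preconditioners acting on the matrix-shaped parameters;
    # preconditioner_update_probability=0.1 refreshes the preconditioner on roughly
    # 10% of the steps.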
    # stage 1 of the experiment
    # note that the noisy label experiment requires a different training loop -- see psgd_cifar10_noisy_label.py
    train_loader, test_loader = get_dataset(experiment, batchsize, data_root, seed_l[run], data_seed_l[run])
    criterion = nn.CrossEntropyLoss()
    train_accs = []
    test_accs = []
    for epoch in range(num_epoch):
        if lr_scheduler == 'cos':
            opt.lr_params = lr0 * (1 + math.cos(math.pi * epoch / num_epoch)) / 2
        elif lr_scheduler == 'exp':
            # exponential decay from lr0 down to 0.1 * lr0 over training
            opt.lr_params = lr0 * 0.1 ** (epoch / num_epoch)
        else:
            # stage scheduler: cut the learning rate by 10x at 70% and 90% of training
            if epoch == int(num_epoch * 0.7):
                opt.lr_params *= 0.1
            if epoch == int(num_epoch * 0.9):
                opt.lr_params *= 0.1
        if epoch == epoch_concept_switch:
            # if there is a second stage of the experiment
            # note that the noisy label dataset requires a different training loop -- see psgd_cifar10_noisy_label.py
            train_loader, test_loader = get_dataset(stage2, batchsize, data_root, seed_l[run], data_seed_l[run])
        train_loss, train_accuracy = train(net, device, train_loader, criterion)
        test_accuracy = test(net, device, test_loader, criterion)
        print(
            "run: {}; epoch: {}; train loss: {:.2f}; train accuracy: {:.2f}; test accuracy: {:.2f}".format(
                run + 1, epoch + 1, train_loss, train_accuracy, test_accuracy
            )
        )
        train_accs.append(train_accuracy)
        test_accs.append(test_accuracy)
print("train_accuracy: {}".format(train_accs))
print("test_accuracy: {}".format(test_accs))
train_accs_l.append(train_accs)
test_accs_l.append(test_accs)
print("train_accuracy: {}".format(train_accs_l))
print("test_accuracy: {}".format(test_accs_l))