# studentRKD.py — 49 lines (43 loc), 2.17 KB
# (Header reconstructed: the original lines here were web-page scrape
# residue — GitHub navigation text and the file's gutter line numbers —
# not part of the source file.)
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import os
import argparse
import models
import models.preact_resnet
from utils import progress_bar, load_data, train, test
from torch.optim.lr_scheduler import MultiStepLR
def main():
    """Train a student PreAct WideResNet on CIFAR-10 by distilling from a
    pre-trained teacher WideResNet checkpoint (Hinton KD and/or relational
    KD losses, weighted by --hkd / --rkd).

    All configuration comes from the command line. Test results are saved
    under a name that encodes every hyper-parameter, including the seed.
    """
    parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
    parser.add_argument('--teacher_depth', default=28, type=int, help='depth of the teacher WideResNet')
    parser.add_argument('--teacher_width', default=1, type=float, help='width multiplier of the teacher')
    parser.add_argument('--depth', default=28, type=int, help='depth of the student network')
    parser.add_argument('--width', default=0.5, type=float, help='width multiplier of the student')
    parser.add_argument('--hkd', default=0., type=float, help='weight of the Hinton (logit) KD loss')
    parser.add_argument('--temp', default=0., type=float, help='softmax temperature for the HKD loss')
    parser.add_argument('--rkd', default=0., type=float, help='weight of the relational KD loss')
    parser.add_argument('--seed', default=0, type=int, help='random seed for reproducibility')
    parser.add_argument('--pool3', action='store_true', help='distill from the pool3 features only')
    args = parser.parse_args()
    # BUG FIX: --seed was recorded in the save name but never actually
    # applied, so runs tagged with a seed were not reproducible.
    torch.manual_seed(args.seed)
    save = "RKD_{}-{}_teaches_{}-{}_{}_{}_{}_{}_{}".format(args.teacher_depth,args.teacher_width,args.depth,args.width,args.hkd,args.temp,args.rkd,args.pool3,args.seed)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        # Seed every GPU as well (manual_seed alone covers only the CPU RNG
        # and the current device).
        torch.cuda.manual_seed_all(args.seed)
    trainloader, testloader = load_data(128)
    file = "checkpoint/WideResNet{}-{}_0.pth".format(args.teacher_depth,args.teacher_width)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # machine; without it torch.load raises on deserializing CUDA tensors.
    teacher = torch.load(file, map_location=device)["net"].module
    teacher = teacher.to(device)
    # NOTE(review): teacher is left in training mode here — presumably the
    # project's train() handles eval()/no_grad for it; verify in utils.train.
    net = models.preact_resnet.PreActWideResNetStart(depth=args.depth,width=args.width,num_classes=10)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    # Standard WideResNet schedule: decay LR by 5x at epochs 60/120/160.
    scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    for epoch in range(200):
        print('Epoch: %d' % epoch)
        # train() receives the scheduler, so stepping is assumed to happen
        # inside it (it is not stepped here).
        train(net,trainloader,scheduler, device, optimizer,teacher=teacher,lambda_hkd=args.hkd,lambda_rkd=args.rkd,temp=args.temp,classes=10,pool3_only=args.pool3)
        test(net,testloader, device, save_name=save)
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()