# without_aug_utils.py
import time
import sys

import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets
from torchattacks import PGD

from util import TwoCropTransform, AverageMeter
from util import adjust_learning_rate, warmup_learning_rate
from networks.resnet_big import SupConResNet
from losses import SupConLoss

# apex is optional: it is only needed when opt.syncBN is set in set_model.
try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass
def set_loader(opt):
    """Construct the training DataLoader. Unlike the standard SupCon
    pipeline, TwoCropTransform is not applied and color jitter, random
    grayscale, and normalization are commented out, consistent with the
    file name ("without_aug")."""
    if opt.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif opt.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif opt.dataset == 'path':
        mean = eval(opt.mean)
        std = eval(opt.std)
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        # transforms.RandomApply([
        #     transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        # ], p=0.8),
        # transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        # normalize,
    ])

    if opt.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=train_transform,
                                         download=True)
    elif opt.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=train_transform,
                                          download=True)
    elif opt.dataset == 'path':
        train_dataset = datasets.ImageFolder(root=opt.data_folder,
                                             transform=train_transform)
    else:
        raise ValueError(opt.dataset)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
        num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)
    return train_loader
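
# Usage sketch for set_loader (illustrative, not part of the original module;
# the field values below are assumptions about how `opt` is populated):
#
#   import argparse
#   opt = argparse.Namespace(dataset='cifar10', data_folder='./data',
#                            size=32, batch_size=256, num_workers=8)
#   train_loader = set_loader(opt)
#   images, labels = next(iter(train_loader))   # images: (256, 3, 32, 32)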
def set_model(opt):
    """Build the SupCon encoder and contrastive loss, moving both to GPU
    when available."""
    model = SupConResNet(name=opt.model)
    criterion = SupConLoss(temperature=opt.temp)

    # enable synchronized Batch Normalization (requires apex)
    if opt.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            # note: device ids are hardcoded to the first two GPUs
            model.encoder = torch.nn.DataParallel(model.encoder, device_ids=[0, 1])
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion
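
# Usage sketch for set_model (illustrative; the `opt` fields shown are
# assumptions matching the attributes the function reads):
#
#   opt = argparse.Namespace(model='resnet50', temp=0.07, syncBN=False)
#   model, criterion = set_model(opt)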
def adv_train2(train_loader, model, criterion, optimizer, epoch, opt, multi_atk, ema):
    """One epoch of adversarial contrastive training."""
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # multi_atk is expected to return a list of adversarial views; the
        # clean batch is appended as one more view, then a fixed subset of
        # views is selected for training.
        adv_images = multi_atk(images, labels, loss=criterion)
        adv_images.append(images)
        img_to_use = [7, 8, 9, 10]
        adv_images = [adv_images[i] for i in img_to_use]
        view_num = len(adv_images)
        adv_images = torch.cat(adv_images, dim=0)  # shape -> (bsz * view_num, 3, 32, 32)

        # compute loss
        features = model(adv_images)
        fs = torch.split(features, [bsz for i in range(view_num)], dim=0)  # each f -> (bsz, 128)
        features = torch.cat([f.unsqueeze(1) for f in fs], dim=1)  # (bsz, view_num, 128)
        if opt.method == 'SupCon':
            loss = criterion(features, labels)
        elif opt.method == 'SimCLR':
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.format(opt.method))

        # update metric
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if ema:
            ema.update(model.parameters())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train adv multi: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                      epoch, idx + 1, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg
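
# ---------------------------------------------------------------------------
# Driver sketch (illustrative; not part of the original module).
#
# adv_train2 appends the clean batch and then indexes the view list with
# img_to_use = [7, 8, 9, 10], so `multi_atk` must return a list long enough
# for those indices. The stock torchattacks.PGD returns a single tensor and
# takes only (images, labels), while adv_train2 passes a `loss=` keyword, so
# a customized multi-attack wrapper is implied. A minimal plausible shape,
# with hypothetical epsilon values, might be:
#
#   def multi_atk(images, labels, loss):
#       # `loss` is accepted to match the call site but unused in this sketch
#       attacks = [PGD(model, eps=e / 255, alpha=2 / 255, steps=10)
#                  for e in range(2, 24, 2)]   # 11 strengths, hypothetical
#       return [atk(images, labels) for atk in attacks]
#
# A matching epoch loop, with `parse_option` as a hypothetical option parser:
#
#   if __name__ == '__main__':
#       opt = parse_option()
#       train_loader = set_loader(opt)
#       model, criterion = set_model(opt)
#       optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate,
#                                   momentum=opt.momentum,
#                                   weight_decay=opt.weight_decay)
#       for epoch in range(1, opt.epochs + 1):
#           adjust_learning_rate(opt, optimizer, epoch)
#           avg_loss = adv_train2(train_loader, model, criterion, optimizer,
#                                 epoch, opt, multi_atk, ema=None)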