# main.py
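"""Certified Lipschitz-constant bounds for neural networks.

Depending on the parsed arguments (defined in the `parser` module), this
script computes local Lipschitz bounds with auto_LiRPA (optionally made
complete via branch-and-bound), a global Lipschitz bound, LipBaB bounds,
monotonicity verification, or clean accuracy.

The exact command-line flags live in the `parser` module; the flag names
below are assumptions for illustration only. A typical invocation might
look like:

    python main.py --data MNIST --eps 0.1 --num_examples 10
"""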
import math
import time

import torch
import torch.nn as nn
import numpy as np

from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.utils import get_spec_matrix, logger
from auto_LiRPA.perturbations import PerturbationLpNorm

from datasets import load_data
from parser import parse_args
from bab import bab_gradnorm
from monotonicity import monotonicity
from lipbab.LipBaB import lipbab
from convert import conv_to_linear
from global_lip import global_lip
from utils import prepare_model, evaluate_clean
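

# Complete bound on the Jacobian norm via branch-and-bound (BaB).
# Strategy: first compute a cheap incomplete bound for every output class
# (bab=False), then run full BaB on the classes in descending order of that
# bound, splitting the remaining time budget (args.timeout) evenly across
# classes. Note that this function reads the global `args` parsed in
# __main__.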
def compute_complete_jacobian_bounds(model, x, labels):
    assert args.bab
    time_begin = time.time()
    assert x.size(0) == 1  # Batch size must be 1.
    c = torch.ones(1, 1, 1).to(x)  # Specification for the backward (gradient) graph
    c_forward = get_spec_matrix(x, labels, args.num_classes)  # For the forward graph
    bounds = []
    # Incomplete pass: bound the gradient norm for every output class.
    for j in range(args.num_classes):
        grad_start = torch.zeros(1, 1, args.num_classes).to(x)
        grad_start[0, 0, j] = 1
        # Run the forward graph and the gradient-augmented graph, presumably
        # to refresh intermediate bounds before calling bab_gradnorm().
        model(x, grad_start, final_node_name=model.forward_final_name)
        model(x, grad_start)
        ret = -bab_gradnorm(model, x, grad_start, c=-c, c_forward=c_forward,
                            opt_forward=(len(bounds) == 0), args=args, bab=False)
        if args.norm == 2:
            # The L2 bound is computed on the squared norm, so take the root.
            ret = math.sqrt(ret)
        bounds.append(ret)
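    # Complete pass: run BaB on each class, starting from the largest
    # incomplete bound; each class gets an equal share of the remaining
    # time budget.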
    bounds = torch.tensor(bounds)
    sort_label = torch.argsort(-bounds)
    time_remaining = args.timeout - (time.time() - time_begin)
    time_per_class = time_remaining / len(sort_label)
    for j in sort_label:
        time_begin_class = time.time()
        grad_start = torch.zeros(1, 1, args.num_classes).to(x)
        grad_start[0, 0, j] = 1
        model(x, grad_start, final_node_name=model.forward_final_name)
        model(x, grad_start)
        time_remaining = args.timeout - (time.time() - time_begin)
        timeout = min(
            time_per_class - (time.time() - time_begin_class), time_remaining)
        ret = -bab_gradnorm(
            model, x, grad_start,
            c=-c, c_forward=c_forward, args=args, timeout=timeout)
        if args.norm == 2:
            ret = math.sqrt(ret)
        print(f'class {j}, ret {ret}\n')
        bounds[j] = ret
        if time.time() - time_begin >= args.timeout:
            break
    print(f'Worst class {sort_label[0]}->{torch.argmax(bounds)},',
          f'label {labels.item()}')
    print(bounds)
    return bounds.max()
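

# Bound the local Lipschitz constant of `model` around `data` by bounding
# the Jacobian over the Lp ball [data_lb, data_ub]: incomplete bounds via
# auto_LiRPA's compute_jacobian_bounds(), or the complete BaB procedure
# above when args.bab is set.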
def lirpa_local_lipschitz(model, data, labels, data_lb, data_ub, args=None):
    ptb = PerturbationLpNorm(norm=args.norm, x_L=data_lb, x_U=data_ub)
    x = BoundedTensor(data, ptb)
    if not args.bab:
        return model.compute_jacobian_bounds(x, labels=labels)
    else:
        return compute_complete_jacobian_bounds(model, x, labels)
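

# Evaluate local Lipschitz bounds on the test examples with indices in
# [args.start, args.start + args.num_examples): build an eps-ball around
# each (normalized) input, clipped to the valid data range, and report the
# average bound and average per-example runtime.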
def local_lipschitz(args, model, loader, eps=None):
    data_max, data_min, std = loader.data_max, loader.data_min, loader.std
    if args.device == 'cuda':
        data_min, data_max, std = data_min.cuda(), data_max.cuda(), std.cuda()
    model.eval()
    avg = 0
    begin = time.time()
    # Scale eps into normalized-data units and broadcast it to the input shape.
    eps = (eps / std).view(1, -1, *([1] * (data_min.ndim - 2)))
    indices = range(args.start, args.start + args.num_examples)
    for i, idx in enumerate(indices):
        data, labels = loader.dataset[idx]
        data = data.unsqueeze(0)
        labels = torch.tensor([labels])
        print(f'Example {i}: labels {labels}')
        data, labels = data.to(args.device), labels.to(args.device)
        # Clip the eps-ball to the valid data range.
        data_lb = torch.max(data - eps, data_min)
        data_ub = torch.min(data + eps, data_max)
        instance_begin = time.time()
        if args.method == 'global':
            ans = global_lip(model, args=args)
        elif args.method == 'lipbab':
            ans = lipbab(model, data, labels, data_lb, data_ub, args=args)
        else:
            ans = lirpa_local_lipschitz(
                model, data, labels, data_lb, data_ub, args=args)
        if ans is not None:
            print(ans)
            avg += ans
        print('time', time.time() - instance_begin)
        print('\n\n')
    avg_lip = avg / len(indices)
    avg_time = (time.time() - begin) / len(indices)
    print(f'avg_lip {avg_lip:.2f} avg_time {avg_time:.2f}')
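

# Entry point: parse arguments, build the model and data, wrap the model
# with auto_LiRPA's BoundedModule, augment it with a gradient (backward)
# graph, and dispatch to the requested analysis.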
if __name__ == '__main__':
    args = parse_args()
    logger.info('Arguments: %s', args)
    model_ori = prepare_model(args)
    logger.info('Model structure: \n%s', model_ori)
    # Batch size must be 1 to verify gradients.
    dummy_input, train_data, test_data = load_data(
        args, args.data, batch_size=1, test_batch_size=1)
    if args.cnn_to_mlp:
        model_ori = conv_to_linear(model_ori, dummy_input.shape)
        dummy_input = dummy_input.view(dummy_input.size(0), -1)
        logger.info('CNN converted to MLP: \n%s', model_ori)
    model_ori.to(args.device)
    dummy_input = dummy_input.to(args.device)
    dummy_output = model_ori(dummy_input)
    logger.info('Converting the original model')
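    # auto_LiRPA's 'patches' mode is the memory-efficient representation for
    # convolutions; fall back to the dense 'matrix' mode when a conv layer
    # has stride != 1, which patches mode presumably does not handle here.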
    conv_mode = 'patches'
    for layer in model_ori._modules.values():
        if isinstance(layer, nn.Conv2d) and layer.stride[0] != 1:
            logger.info(
                'Using matrix mode due to convolutional layers with stride != 1')
            conv_mode = 'matrix'
    bound_opts = {
        'optimize_bound_args': {
            'ob_iteration': args.ob_iteration,
            'ob_lr_decay': args.ob_lr_decay,
            'ob_lr': args.ob_lr,
            'ob_no_float64_last_iter': True,
        },
        'sparse_intermediate_bounds': True,
        'sparse_conv_intermediate_bounds': True,
        'sparse_intermediate_bounds_with_ibp': True,
        'sparse_features_alpha': False,
        'sparse_spec_alpha': False,
        'conv_mode': conv_mode,
        'lip_method': args.method,
    }
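    # The RecurJac relaxation is toggled through bound_opts rather than a
    # separate code path.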
    if args.method == 'recurjac':
        bound_opts['recurjac'] = True
    model = BoundedModule(
        model_ori, dummy_input, bound_opts=bound_opts, device=args.device)
    if args.norm != np.inf:
        logger.warning('Using a norm other than inf is not recommended.')
    model.augment_gradient_graph(dummy_input, norm=args.norm)
    print('\n\n\n')
    logger.info('Starting computation...')
    if args.mono:
        monotonicity(args, model, test_data)
    elif args.clean:
        evaluate_clean(model_ori, test_data)
    else:
        local_lipschitz(args, model, test_data, eps=args.eps)