Support PyTorch 1.7
Co-authored-by: Zhouxing Shi <zhouxingshichn@gmail.com>
Co-authored-by: Huan Zhang <huan@huan-zhang.com>
Co-authored-by: Yihan Wang <wangyihan617@gmail.com>
4 people committed Dec 24, 2020
1 parent 7e1fbf1 commit 1d7a278
Showing 17 changed files with 165 additions and 121 deletions.
5 changes: 1 addition & 4 deletions .travis.yml
@@ -2,13 +2,10 @@ language: python
 python:
 - "3.7"
 install:
-# fix torch's version for result check
-- pip install torch==1.5.0
 - pip install --editable .
 - cd examples
 - pip install -r requirements.txt
-# fix torchvision's version to make it compatible with the torch
-- pip install torchvision==0.6.0
+- pip install torchvision==0.6.0 torch==1.7.0
 - cd ..
 script:
 - cd tests
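The CI now pins torch 1.7.0 in the same `pip install` step as torchvision, so the result checks in `tests/` run against the new version. A hedged local sanity check mirroring the pin (this snippet is illustrative, not part of the repository):

```python
import torch

# The version-pinned result checks assume this exact minor version.
assert torch.__version__.startswith('1.7'), f'expected torch 1.7.x, got {torch.__version__}'
print('torch', torch.__version__)
```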
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
 # auto_LiRPA: Automatic Linear Relaxation based Perturbation Analysis for Neural Networks
 
-![](https://travis-ci.com/KaidiXu/CROWN-GENERAL.svg?token=HM3jb55xV1sMRsVKBr8b&branch=master&status=started)
+![](https://travis-ci.com/KaidiXu/auto_LiRPA.svg?token=HM3jb55xV1sMRsVKBr8b&branch=master&status=started)
 
 <p align="center">
 <img src="http://www.huan-zhang.com/images/upload/lirpa/auto_lirpa_2.png" width="45%" height="45%" float="left">
73 changes: 38 additions & 35 deletions auto_LiRPA/bound_general.py
@@ -141,7 +141,7 @@ def forward(self, *x, final_node_name=None):
             for l_pre in l.input_name:
                 l.from_input = l.from_input or self._modules[l_pre].from_input
             fv = l.forward(*inp)
-            if isinstance(fv, torch.Size):
+            if isinstance(fv, torch.Size) or isinstance(fv, tuple):
                 fv = torch.tensor(fv, device=self.device)
             object.__setattr__(l, 'forward_value', fv)
             object.__setattr__(l, 'fv', fv)
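Context for this hunk: under PyTorch 1.7, some shape-producing operations in the traced graph yield plain tuples rather than `torch.Size`, so the forward pass now normalizes both to tensors. A minimal sketch of the behavior (the helper name is hypothetical, not part of the library):

```python
import torch

def normalize_forward_value(fv, device='cpu'):
    # torch.Size is a tuple subclass, but an explicit tuple check is still
    # needed for ops that return plain tuples under PyTorch 1.7.
    if isinstance(fv, (torch.Size, tuple)):
        fv = torch.tensor(fv, device=device)
    return fv

print(normalize_forward_value(torch.empty(2, 3).shape))  # tensor([2, 3])
print(normalize_forward_value((2, 3)))                    # tensor([2, 3])
```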
@@ -504,10 +504,12 @@ def get_optimized_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False,
             assert l.shape[1] == 1
             if i == 0 or i == iteration - 1:
                 print('optimal slope:', l.flatten(), scheduler.get_last_lr())
+            if (l > 1e-4).all(): # all lower bounds > 0, no need to optimize
+                break
 
             l = l.sum()
             opt.zero_grad()
-            l = -1 * l
+            l = (-1 * l) * (l < 1e-4) # only optimize the lower bounds < 0
             # early stop
             if last_l <= l and iteration < 100:
                 break
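The intent of the two changes above, per the inline comments: entries whose lower bound already clears the threshold drop out of the objective, and optimization stops entirely once every lower bound is positive. A toy illustration of the elementwise reading of that mask (hypothetical values, mask applied before the sum for clarity):

```python
import torch

l = torch.tensor([-0.5, 2.0, -0.01, 3.0], requires_grad=True)

if (l > 1e-4).all():                        # all lower bounds verified: stop
    print('done')
else:
    loss = ((-1 * l) * (l < 1e-4)).sum()    # only unverified entries contribute
    loss.backward()
    print(l.grad)                           # tensor([-1., 0., -1., 0.])
```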
@@ -521,7 +523,9 @@ def get_optimized_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False,
         return self.compute_bounds(x, aux, C, IBP, forward, method, bound_lower, bound_upper, reuse_ibp, return_A, final_node_name, average_A, new_interval)
 
     def compute_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, method='backward', bound_lower=True,
-                       bound_upper=True, reuse_ibp=False, return_A=False, final_node_name=None, average_A=False, new_interval=None):
+                       bound_upper=True, reuse_ibp=False,
+                       return_A=False, final_node_name=None, average_A=False, new_interval=None,
+                       return_b=False, b_dict=None):
         if not bound_lower and not bound_upper:
             raise ValueError('At least one of bound_lower and bound_upper in compute_bounds should be True')
         A_dict = {} if return_A else None
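A hedged usage sketch for the new `return_b`/`b_dict` parameters, based on how they are threaded through to `_backward_general` below: the caller supplies a dictionary, which is populated in place with the accumulated bias terms per node during the backward pass (variable names here are illustrative):

```python
# model is a BoundedModule and data a BoundedTensor; names are illustrative.
b_dict = {}
lb, ub = model.compute_bounds(x=(data,), method='backward',
                              return_b=True, b_dict=b_dict)
for name, b in b_dict.items():
    print(name, b['lower_b'], b['upper_b'])
```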
@@ -533,18 +537,10 @@ def compute_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, method='backward',
         if method == 'forward':
             forward = True
         root = [self._modules[name] for name in self.root_name]
-        batch_size = root[0].value.shape[0]
+        batch_size = root[0].fv.shape[0]
         dim_in = 0
         for i in range(len(root)):
-            if type(root[i]) == BoundInput:
-                value = root[i].forward_value = root[i].value
-            elif type(root[i]) == BoundParams:
-                value = root[i].forward_value = root[i].param
-            elif type(root[i]) == BoundBuffers:
-                value = root[i].forward_value = root[i].buffer
-            else:
-                # a detached intermediate node, which can be treated as an independent node in bound computation
-                value = root[i].forward_value
+            value = root[i].forward()
             if root[i].perturbation is not None:
                 root[i].linear, root[i].center, root[i].aux = \
                     root[i].perturbation.init(value, aux=aux, forward=forward)
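The hunk above replaces per-type dispatch (`BoundInput.value`, `BoundParams.param`, `BoundBuffers.buffer`) with a single polymorphic `forward()` call on each root node. A schematic of the idea, not the library's actual class definitions:

```python
class BoundInput:
    def __init__(self, value):
        self.value = value
    def forward(self):
        # input nodes yield their concrete input value
        return self.value

class BoundParams:
    def __init__(self, param):
        self.param = param
    def forward(self):
        # parameter nodes yield the parameter tensor
        return self.param

# The caller no longer needs isinstance checks:
# value = root[i].forward()
```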
@@ -651,7 +647,8 @@ def compute_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, method='backward',
 
         if method == 'backward':
             return self._backward_general(C=C, node=final, root=root, bound_lower=bound_lower, bound_upper=bound_upper,
-                                          return_A=return_A, average_A=average_A, A_dict=A_dict)
+                                          return_A=return_A, average_A=average_A, A_dict=A_dict,
+                                          return_b=return_b, b_dict=b_dict)
         elif method == 'forward':
             return self._forward_general(C=C, node=final, root=root, dim_in=dim_in, concretize=True)
         else:
@@ -675,8 +672,8 @@ def _IBP_loss_fusion(self, node, C):
         node_linear = self._modules[node.input_name[0]]
         node_start = self._modules[node_linear.input_name[0]]
         if isinstance(node_linear, BoundLinear):
-            w = self._modules[node_linear.input_name[1]].forward_value
-            b = self._modules[node_linear.input_name[2]].forward_value
+            w = self._modules[node_linear.input_name[1]].fv
+            b = self._modules[node_linear.input_name[2]].fv
             labels = self._modules[node_gather.input_name[1]]
             if not hasattr(node_start, 'interval'):
                 self._IBP_general(node_start)
@@ -741,7 +738,8 @@ def _IBP_general(self, node=None, C=None):
 
         return node.interval
 
-    def _backward_general(self, C=None, node=None, root=None, bound_lower=True, bound_upper=True, return_A=False, average_A=False, A_dict=None):
+    def _backward_general(self, C=None, node=None, root=None, bound_lower=True, bound_upper=True,
+                          return_A=False, average_A=False, A_dict=None, return_b=False, b_dict=None):
         logger.debug('Backward from ({})[{}]'.format(node, node.name))
 
         _print_time = False
@@ -762,7 +760,6 @@ def _backward_general(self, C=None, node=None, root=None, bound_lower=True, bound_upper=True,
         node.bounded = True
         batch_size, output_dim = C.shape[:2]
 
-
         if not isinstance(C, eyeC) and not isinstance(C, Patches):
             C = C.transpose(0, 1)
         elif isinstance(C, eyeC):
@@ -783,10 +780,17 @@ def _get_A_shape(node):
             return shape_A
 
         queue = deque([node])
+        A_record = {}
         while len(queue) > 0:
             l = queue.popleft()  # backward from l
             l.bounded = True
 
+            if return_b:
+                b_dict[l.name] = {
+                    'lower_b': lb,
+                    'upper_b': ub
+                }
+
             if l.name in self.root_name or l == root: continue
 
             for l_pre in l.input_name:  # if all the succeeds are done, then we can turn to this node in the next iteration.
@@ -854,7 +858,9 @@ def add_bound(node, lA, uA):
                 logger.debug('Backward at {}[{}], fv shape {}, {}'.format(
                     l, l.name, l.forward_value.shape, _get_A_shape(l)))
             except: pass
-            A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *input_nodes)
+            A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *input_nodes)
+            if return_A:
+                A_record.update({l.name: A})
         except:
             raise Exception('Error at bound_backward of {}, {}'.format(l, l.name))

@@ -863,9 +869,7 @@ def add_bound(node, lA, uA):
             if time_elapsed > 1e-3:
                 print(l, time_elapsed)
             lb = lb + lower_b
-            ub = ub + upper_b
-
-            logger.debug('ub mean {}'.format(torch.mean(ub)))
+            ub = ub + upper_b
 
             for i, l_pre in enumerate(l.input_name):
                 try: logger.debug('  {} -> {}, uA shape {}'.format(l.name, l_pre, A[i][1].shape))
@@ -887,6 +891,7 @@ def add_bound(node, lA, uA):
             for i in range(len(root)):
                 if root[i].lA is None and root[i].uA is None: continue
                 this_A_dict.update({root[i].name: [root[i].lA, root[i].uA]})
+            this_A_dict.update(A_record)
             A_dict.update({node.name: this_A_dict})
 
         for i in range(len(root)):
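With this addition, the `A_dict` returned under `return_A=True` contains not only the root nodes' coefficient matrices but also the per-node `A` tuples collected in `A_record` during the backward pass. A hedged retrieval sketch (names illustrative):

```python
lb, ub, A_dict = model.compute_bounds(x=(data,), method='backward', return_A=True)
# A_dict maps the final node's name to a per-node dictionary that now mixes
# root entries ([lA, uA] pairs) with recorded intermediate entries (A tuples).
for final_name, per_node in A_dict.items():
    for node_name, A in per_node.items():
        print(final_name, node_name, type(A))
```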
@@ -905,8 +910,6 @@ def add_bound(node, lA, uA):
             if not isinstance(root[i].uA, eyeC) and not isinstance(root[i].lA, Patches):
                 uA = root[i].uA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_upper else None
             if root[i].perturbation is not None:
-                logger.debug('Concretize {}[{}], fv shape {}, {}'.format(
-                    root[i], root[i].name, root[i].forward_value.shape, _get_A_shape(root[i])))
                 if isinstance(root[i], BoundParams):
                     # add batch_size dim for weights node
                     lb = lb + root[i].perturbation.concretize(
@@ -921,24 +924,24 @@ def add_bound(node, lA, uA):
             # FIXME to simplify
             elif i < self.num_global_inputs:
                 if not isinstance(lA, eyeC):
-                    lb = lb + lA.bmm(root[i].value.view(batch_size, -1, 1)).squeeze(-1) if bound_lower else None
+                    lb = lb + lA.bmm(root[i].fv.view(batch_size, -1, 1)).squeeze(-1) if bound_lower else None
                 else:
-                    lb = lb + root[i].value.view(batch_size, -1) if bound_lower else None
+                    lb = lb + root[i].fv.view(batch_size, -1) if bound_lower else None
                 if not isinstance(uA, eyeC):
                     # FIXME looks questionable
-                    ub = ub + uA.bmm(root[i].value.view(batch_size, -1, 1)).squeeze(-1) if bound_upper else None
+                    ub = ub + uA.bmm(root[i].fv.view(batch_size, -1, 1)).squeeze(-1) if bound_upper else None
                 else:
-                    ub = ub + root[i].value.view(batch_size, -1) if bound_upper else None
+                    ub = ub + root[i].fv.view(batch_size, -1) if bound_upper else None
             else:
                 if not isinstance(lA, eyeC):
-                    lb = lb + lA.matmul(root[i].param.view(-1, 1)).squeeze(-1) if bound_lower else None
+                    lb = lb + lA.matmul(root[i].fv.view(-1, 1)).squeeze(-1) if bound_lower else None
                 else:
-                    lb = lb + root[i].param.view(1, -1) if bound_lower else None
+                    lb = lb + root[i].fv.view(1, -1) if bound_lower else None
                 if not isinstance(uA, eyeC):
                     # FIXME looks questionable
-                    ub = ub + uA.matmul(root[i].param.view(-1, 1)).squeeze(-1) if bound_upper else None
+                    ub = ub + uA.matmul(root[i].fv.view(-1, 1)).squeeze(-1) if bound_upper else None
                 else:
-                    ub = ub + root[i].param.view(1, -1) if bound_upper else None
+                    ub = ub + root[i].fv.view(1, -1) if bound_upper else None
 
         node.lower = lb.view(batch_size, *output_shape) if bound_lower else None
         node.upper = ub.view(batch_size, *output_shape) if bound_upper else None
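These renames route concretization through the unified `fv` attribute. For an unperturbed root, the code evaluates the linear bound exactly at the stored value; a minimal sketch of the `bmm` branch with hypothetical shapes:

```python
import torch

batch_size, output_dim, in_dim = 2, 3, 4
lA = torch.randn(batch_size, output_dim, in_dim)   # lower-bound coefficients
lb = torch.zeros(batch_size, output_dim)           # accumulated bias
fv = torch.randn(batch_size, in_dim)               # fixed root value (root[i].fv)

# mirrors: lb = lb + lA.bmm(root[i].fv.view(batch_size, -1, 1)).squeeze(-1)
lb = lb + lA.bmm(fv.view(batch_size, -1, 1)).squeeze(-1)
print(lb.shape)  # torch.Size([2, 3])
```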
@@ -1019,7 +1022,7 @@ def _init_forward(self, root, dim_in):
         if dim_in == 0:
             raise ValueError("At least one node should have a specified perturbation")
         prev_dim_in = 0
-        batch_size = root[0].value.shape[0]
+        batch_size = root[0].fv.shape[0]
         for i in range(len(root)):
             if root[i].perturbation is not None:
                 shape = root[i].linear.lw.shape
@@ -1121,7 +1124,7 @@ def forward(self, *inputs, **kwargs):
         # inputs_scatter = inputs_scatter[0]
         bounded_inputs = []
         for input_s in inputs_scatter:  # GPU numbers
-            ptb = PerturbationLpNorm(inputs[0].ptb.norm, inputs[0].ptb.eps, x_L=input_s[1], x_U=input_s[2])
+            ptb = PerturbationLpNorm(norm=inputs[0].ptb.norm, eps=inputs[0].ptb.eps, x_L=input_s[1], x_U=input_s[2])
             # bounded_inputs.append(tuple([(BoundedTensor(input_s[0][0], ptb))]))
             input_s = list(input_s[0])
             input_s[0] = BoundedTensor(input_s[0], ptb)
@@ -1136,7 +1139,7 @@ def forward(self, *inputs, **kwargs):
             bounded_inputs = []
             inputs_scatter, kwargs = self.scatter((inputs, x.ptb.x_L, x.ptb.x_U), kwargs, self.device_ids)
             for input_s, kw_s in zip(inputs_scatter, kwargs):  # GPU numbers
-                ptb = PerturbationLpNorm(x.ptb.norm, x.ptb.eps, x_L=input_s[1], x_U=input_s[2])
+                ptb = PerturbationLpNorm(norm=x.ptb.norm, eps=x.ptb.eps, x_L=input_s[1], x_U=input_s[2])
                 kw_s['x'] = list(kw_s['x'])
                 kw_s['x'][0] = BoundedTensor(kw_s['x'][0], ptb)
                 kw_s['x'] = (kw_s['x'])
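Both DataParallel paths now pass `norm` and `eps` to `PerturbationLpNorm` by keyword, keeping the calls robust if the constructor's positional order changes. A hedged construction sketch:

```python
import numpy as np
from auto_LiRPA.perturbations import PerturbationLpNorm

# keyword arguments make the call independent of parameter order
ptb = PerturbationLpNorm(norm=np.inf, eps=0.1)
# per-GPU slices can instead pass explicit bounds (slice names illustrative):
# ptb = PerturbationLpNorm(norm=np.inf, eps=0.1, x_L=x_l_slice, x_U=x_u_slice)
```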
(Diffs for the remaining 14 changed files are not shown.)
