-
Notifications
You must be signed in to change notification settings - Fork 94
/
Copy pathrunMarabou.py
executable file
·165 lines (147 loc) · 6.57 KB
/
runMarabou.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#! /usr/bin/env python3
'''
Top contributors (to current version):
- Andrew Wu
This file is part of the Marabou project.
Copyright (c) 2017-2021 by the authors listed in the file AUTHORS
(in the top-level source directory) and their institutional affiliations.
All rights reserved. See the file COPYING in the top-level source
directory for licensing information.
'''
import argparse
import numpy as np
import os
import sys
import tempfile
import pathlib
sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../"))
from maraboupy import Marabou
from maraboupy import MarabouCore
from maraboupy import MarabouUtils
import subprocess
def main():
    """Build a Marabou input query from the command line and run Marabou on it.

    The script's own arguments are parsed with ``arguments()``. Every argument
    it does not recognize is forwarded verbatim to the Marabou binary. The
    query built by ``createQuery`` is serialized to a temporary file, which is
    passed to Marabou via ``--input-query`` and removed afterwards.

    Exits with status 1 when no input query can be built, and with an error
    message when the Marabou binary is missing or not executable.
    """
    args, unknown = arguments().parse_known_args()
    query, _ = createQuery(args)
    if query is None:
        print("Unable to create an input query!")
        print("There are three options to define the benchmark:\n"
              "1. Provide an input query file.\n"
              "2. Provide a network and a property file.\n"
              "3. Provide a network, a dataset (--dataset), an epsilon (-e), "
              "target label (-t), and the index of the point in the test set (-i).")
        # sys.exit rather than the `exit` builtin, which only exists when the
        # `site` module is loaded.
        sys.exit(1)

    marabou_binary = args.marabou_binary
    if not os.access(marabou_binary, os.X_OK):
        sys.exit('"{}" does not exist or is not executable'.format(marabou_binary))

    # Only reserve a unique file name here, and close our handle right away so
    # that saveQuery can reopen the path itself. This also avoids leaking the
    # descriptor; on Windows an open NamedTemporaryFile cannot be reopened.
    with tempfile.NamedTemporaryFile(dir=args.temp_dir, delete=False) as temp:
        name = temp.name
    try:
        MarabouCore.saveQuery(query, name)
        print("Running Marabou with the following arguments: ", unknown)
        subprocess.run([marabou_binary, "--input-query={}".format(name)] + unknown)
    finally:
        # Remove the serialized query even if saving or solving fails or the
        # run is interrupted.
        os.remove(name)
def createQuery(args):
    """Build a Marabou input query from parsed command-line arguments.

    The benchmark source is chosen in this order of priority:

    1. ``args.input_query``: a saved input-query file, loaded as-is.
    2. ``args.network`` plus ``args.prop``: a network file (``.nnet``, ``.pb``
       or ``.onnx``) together with a property file.
    3. ``args.network`` plus ``args.dataset``: an L-infinity robustness
       property around test point ``args.index`` with radius
       ``args.epsilon`` and target ``args.target_label`` ('mnist' or
       'cifar10'). For any other dataset the network's query is returned
       with no property encoded.

    Args:
        args: namespace produced by ``arguments().parse_known_args()``.

    Returns:
        A ``(query, network)`` pair.
        - ``network`` is None when the query was loaded from an input-query
          file.
        - ``(None, None)`` is returned when no network was given or the
          network format is unsupported, so the caller can print usage help.
    """
    if args.input_query:
        return Marabou.load_query(args.input_query), None

    networkPath = args.network
    if networkPath is None:
        # Neither an input query nor a network was supplied. Report "no
        # query" instead of crashing on None.split below.
        return None, None

    suffix = networkPath.split('.')[-1]
    if suffix == "nnet":
        network = Marabou.read_nnet(networkPath)
    elif suffix == "pb":
        network = Marabou.read_tf(networkPath)
    elif suffix == "onnx":
        network = Marabou.read_onnx(networkPath)
    else:
        print("The network must be in .pb, .nnet, or .onnx format!")
        return None, None

    if args.prop is not None:
        query = network.getInputQuery()
        MarabouCore.loadProperty(query, args.prop)
        return query, network

    if args.dataset == 'mnist':
        encode_mnist_linf(network, args.index, args.epsilon, args.target_label)
    elif args.dataset == 'cifar10':
        encode_cifar10_linf(network, args.index, args.epsilon, args.target_label)
    else:
        # ENCODE YOUR CUSTOMIZED PROPERTY HERE!
        print("No property encoded!")
    return network.getInputQuery(), network
def encode_mnist_linf(network, index, epsilon, target_label):
    """Encode an L-infinity robustness query around one MNIST test image.

    Loads the MNIST test set through ``tensorflow.keras`` and scales test
    image ``index`` from [0, 255] to [0, 1]. Each input variable is then
    bounded to ``[max(0, p - epsilon), min(1, p + epsilon)]``, where ``p`` is
    the corresponding pixel. The image's true label is printed.

    If ``target_label`` is not -1, then for every other output ``y_i`` this
    adds ``addInequality([y_i, y_target], [1, -1], 0)``. Under maraboupy's
    ``sum(coeffs * vars) <= scalar`` convention, that means
    ``y_i - y_target <= 0``: the target output must be at least every other
    output.

    Args:
        network: Marabou network object whose bounds/constraints are mutated.
        index: index of the image in the MNIST test set.
        epsilon: L-infinity perturbation radius (in the [0, 1] pixel scale).
        target_label: label to encode as the top output, or -1 for no output
            constraint.
    """
    from tensorflow.keras.datasets import mnist
    _, (X_test, Y_test) = mnist.load_data()
    point = np.array(X_test[index]).flatten() / 255
    print("correct label: {}".format(Y_test[index]))

    # Pair each input variable with its pixel by *position*, not by variable
    # number, so this does not rely on input variables being numbered 0..783.
    for i, x in enumerate(np.array(network.inputVars).flatten()):
        network.setLowerBound(x, max(0, point[i] - epsilon))
        network.setUpperBound(x, min(1, point[i] + epsilon))

    if target_label == -1:
        print("No output constraint!")
        return

    outputVars = network.outputVars[0].flatten()
    for i, y in enumerate(outputVars):
        if i != target_label:
            network.addInequality([y, outputVars[target_label]], [1, -1], 0)
def encode_cifar10_linf(network, index, epsilon, target_label):
    """Encode an L-infinity robustness query around one CIFAR-10 test image.

    Loads the CIFAR-10 test set through ``torchvision``, downloading it into
    ``./data/cifardata/`` if needed. Test image ``index`` is flattened and
    each input variable is bounded to
    ``[max(0, p - epsilon), min(1, p + epsilon)]``. The image's true label is
    printed.

    If ``target_label`` is not -1, then for every other output ``y_i`` this
    adds ``addInequality([y_i, y_target], [1, -1], 0)``. Under maraboupy's
    ``sum(coeffs * vars) <= scalar`` convention, that means
    ``y_i - y_target <= 0``.

    Args:
        network: Marabou network object whose bounds/constraints are mutated.
        index: index of the image in the CIFAR-10 test set.
        epsilon: L-infinity perturbation radius.
        target_label: label to encode as the top output, or -1 for no output
            constraint.
    """
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    cifar_test = datasets.CIFAR10('./data/cifardata/', train=False, download=True,
                                  transform=transforms.ToTensor())
    image, label = cifar_test[index]
    # ToTensor produces a channel-first tensor (per torchvision docs, values in
    # [0, 1]). It is flattened in C order, which must match the order of the
    # network's input variables -- assumed here, as in the original code.
    point = image.unsqueeze(0).numpy().flatten()
    lb = np.maximum(point - epsilon, 0)
    ub = np.minimum(point + epsilon, 1)
    print("correct label: {}".format(label))

    # Bug fix: the input bounds used to be applied only when a target label
    # was given, so the default `-t -1` produced an unbounded query. The
    # bounds are now always applied, as in encode_mnist_linf.
    for i, x in enumerate(np.array(network.inputVars).flatten()):
        network.setLowerBound(x, lb[i])
        network.setUpperBound(x, ub[i])

    if target_label == -1:
        print("No output constraint!")
        return

    outputVars = network.outputVars[0].flatten()
    for i, y in enumerate(outputVars):
        if i != target_label:
            network.addInequality([y, outputVars[target_label]], [1, -1], 0)
def arguments():
    """Build the command-line parser for this script.

    ``main`` parses with ``parse_known_args``. Any argument not declared here
    is forwarded unchanged to the Marabou binary.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    ################################ Arguments parsing ##############################
    parser = argparse.ArgumentParser(description="Script to run some canonical benchmarks with Marabou (e.g., ACAS benchmarks, l-inf robustness checks on mnist/cifar10).")
    # benchmark
    parser.add_argument('network', type=str, nargs='?', default=None,
                        help='The network file name, the extension can be only .pb, .nnet, and .onnx')
    parser.add_argument('prop', type=str, nargs='?', default=None,
                        help='The property file name')
    parser.add_argument('-q', '--input-query', type=str, default=None,
                        help='The input query file name')
    parser.add_argument('--dataset', type=str, default=None,
                        help="the dataset (mnist,cifar10)")
    parser.add_argument('-e', '--epsilon', type=float, default=0,
                        help='The epsilon for L_infinity perturbation')
    parser.add_argument('-t', '--target-label', type=int, default=-1,
                        help='The target of the adversarial attack')
    # The short flag was previously registered as '-i,' (stray comma). As a
    # result '-i5' was not recognized and the help showed '-i, INDEX'.
    parser.add_argument('-i', '--index', type=int, default=0,
                        help='The index of the point in the test set')
    parser.add_argument('--temp-dir', type=str, default="/tmp/",
                        help='Temporary directory')
    # Default binary: ../build/Marabou relative to this script's directory.
    marabou_path = os.path.join(str(pathlib.Path(__file__).parent.absolute()),
                                "../build/Marabou")
    parser.add_argument('--marabou-binary', type=str, default=marabou_path,
                        help='The path to Marabou binary')
    return parser
if "__main__" == __name__:
    # Executed as a script (not imported): run the benchmark driver.
    main()