-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcanary_attack_main.py
116 lines (95 loc) · 3.33 KB
/
canary_attack_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import math
from functools import partial
import os, sys
import importlib
import myPickle
import models
from utility import lr_schlr
from canary_utility import local_training
from datasets import load_dataset_classification
# Directory under which each run's pickled results are written.
home_output = './results/'
# Loss used both when scoring the canary and during FedAVG local training.
loss_function = models.sparse_classification_loss


def _parse_args(argv):
    """Return ``(setting, run_id)`` parsed from ``argv``.

    ``argv[1]`` is the importable name of the setting module and ``argv[2]``
    an integer run id. Prints a usage line and exits with status 1 when an
    argument is missing or the id is not an integer.
    """
    try:
        return argv[1], int(argv[2])
    except (IndexError, ValueError):
        # Narrow catch: a bare ``except`` would also swallow KeyboardInterrupt.
        print("[USAGE] setting_file id")
        sys.exit(1)


def _run_name(C, run_id):
    """Build the result file name: ``<run_id>-<setting fields joined by '_'>``."""
    fields = [
        C.dataset_key,
        C.dataset_key_shadow,
        C.injection_type,
        C.pos_w,
        C.batch_size_train,
        C.loss_threshold,
        C.model_id,
        C.canary_id,
        C.learning_rate_fedAVG,
    ]
    return f"{run_id}-" + '_'.join(map(str, fields))


def _evaluate_fedsgd(test_canary_fn, model, C):
    """Score the canary once per batch size in ``C.batch_size_tests``.

    Returns a list of ``(batch_size, (score, failed))`` tuples, where
    ``(score, failed)`` is whatever ``test_canary_fn(model, dataset)`` returns.
    """
    scores = []
    for batch_size in C.batch_size_tests:
        print(f"\tEvaluation FedSGD - batch size: {batch_size} ... ")
        # NOTE(review): this loads the 'train' split of C.dataset_key (the
        # original variable was named ``validation_i``) -- confirm intended.
        dataset, _, _ = load_dataset_classification(
            C.dataset_key, batch_size, split='train', repeat=1
        )
        score, failed = test_canary_fn(model, dataset)
        print(batch_size, score)
        scores.append((batch_size, (score, failed)))
    return scores


def main(argv=None):
    """Run one canary-injection experiment and pickle its results.

    Usage: ``python canary_attack_main.py <setting_module> <id>``

    ``<setting_module>`` is imported with ``importlib.import_module`` and
    supplies every experiment parameter (``C.*``). ``<id>`` is added to
    ``C.rng_seed`` and prefixes the output file name, so repeated runs of one
    setting get distinct seeds and result files.

    The list written by ``myPickle.dump`` to ``home_output/<id>-<fields>``
    holds, in order: the RNG seed, ``x_target`` returned by ``load_dataset``,
    the injection logs, the FedSGD scores (see ``_evaluate_fedsgd``), and the
    value returned by ``local_training`` for FedAVG.

    Exits with status 1 on bad arguments, on an unsupported
    ``C.injection_type``, or when ``inject_canary`` reports that the loss
    threshold was not reached.
    """
    argv = sys.argv if argv is None else argv
    setting, run_id = _parse_args(argv)

    output = []
    C = importlib.import_module(setting)

    # Seed TensorFlow and NumPy from the setting's base seed offset by run id.
    rng_seed = C.rng_seed + run_id
    tf.random.set_seed(rng_seed)
    np.random.seed(rng_seed)
    output.append(rng_seed)

    if C.injection_type != 1:
        # Only injection type 1 is wired up; without this guard the names
        # imported below would be undefined and the run would die later with
        # an unhelpful NameError.
        print(f"Unsupported injection_type: {C.injection_type!r}", file=sys.stderr)
        sys.exit(1)
    from canary_attack import load_dataset, setup_model, evaluate_canary_attack, inject_canary

    name = _run_name(C, run_id)
    output_path = os.path.join(home_output, name)
    print(name)
    # Create the results directory before the expensive work so a missing
    # directory cannot lose a finished run at the final dump (presumably
    # myPickle.dump does not create parent directories -- verify).
    os.makedirs(os.path.dirname(output_path) or os.curdir, exist_ok=True)

    # load datasets
    validation, shadow, x_shape, class_num, (x_target, _) = load_dataset(
        C.dataset_key,
        C.dataset_key_shadow,
        C.batch_size_test,
        C.batch_size_train,
        data_aug_shadow=C.data_aug_shadow,
    )
    output.append(x_target)

    # load model and pick canary location
    model, _, g_canary_shift, kernel_idx, pre_canary_layer_trainable_variables = setup_model(
        C.model_id,
        C.canary_id,
        x_shape,
        class_num,
    )

    print("Injecting canary ...")
    inj_logs, ths_reached = inject_canary(
        C.max_number_of_iters,
        C.batch_size_train,
        model,
        x_target,
        shadow,
        pre_canary_layer_trainable_variables,
        C.opt,
        loss_threshold=C.loss_threshold,
        w=C.pos_w,
    )
    if not ths_reached:
        print("Canary injection failed! Try again.")
        sys.exit(1)
    output.append(inj_logs)

    # Fix every argument except (model, dataset), which the FedSGD loop and
    # local_training supply positionally at each evaluation.
    test_canary_fn = partial(
        evaluate_canary_attack,
        target=x_target,
        variables=pre_canary_layer_trainable_variables,
        loss_function=loss_function,
        g_canary_shift=g_canary_shift,
        kernel_idx=kernel_idx,
        max_num_batches_eval=C.max_num_batches_eval,
    )

    print("Evaluation FedSGD ....")
    output.append(_evaluate_fedsgd(test_canary_fn, model, C))

    print("Evaluation FedAVG ....")
    canary_scores_FedAVG = local_training(
        model,
        validation,
        C.num_iter_fedAVG,
        C.learning_rate_fedAVG,
        loss_function,
        test_canary_fn,
    )
    output.append(canary_scores_FedAVG)

    myPickle.dump(output_path, output)


if __name__ == '__main__':
    main()