# grad_ec.py
## Erase and Check via approximate L0 optimization:
## argmin_{mask} loss(prompt \odot mask, harmful) s.t. mask erases as few tokens as possible
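
# How the relaxation works: the binary erase mask is relaxed to
# sigmoid(mask_logits) in [0, 1]. Each token embedding is linearly
# interpolated between itself and the [PAD] (token id 0) embedding using this
# soft mask, making the classifier's loss differentiable in mask_logits.
# Weight decay pulls the logits towards zero (sigmoid 0.5), biasing the
# procedure towards erasing as few tokens as possible. For the actual
# harmfulness check, the soft mask is thresholded at 0.5 to obtain a hard
# erase mask.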
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import argparse, time, json, os

from defenses import progress_bar

def is_harmful(prompt, model, tokenizer, word_embeddings, num_iters=50):
    """
    Wrapper around the adversarial masking procedure below (grad_ec) that
    exposes it as an "is_harmful" predicate.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Check the unmodified prompt first
    tokens = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    model_output = model(tokens)
    if model_output.logits[0].argmax().item() == 0:
        # Prompt is labelled harmful as-is
        return True

    # Prompt is labelled safe: apply adversarial masking to check whether
    # erasing some of its tokens makes it be labelled harmful
    harmful, _ = grad_ec(prompt, model, tokenizer, word_embeddings,
                         num_iters=num_iters)
    return harmful

def grad_ec(prompt, model, tokenizer, word_embeddings, num_iters=50):
    """
    This function implements GradEC -- a gradient-based approximation of the
    erase-and-check procedure. It optimizes a mask over the input prompt to
    erase tokens and checks whether the resulting prompt is harmful. If the
    masked prompt is labelled harmful, it returns True, else False. It also
    returns the masked prompt.

    Args:
        prompt: Input prompt
        model: Safety classifier
        tokenizer: Tokenizer for the safety classifier
        word_embeddings: Embedding layer of the safety classifier
        num_iters: Number of iterations

    Returns:
        harmful: Whether the input prompt is harmful or not
        masked_prompt: Input prompt with some tokens masked out
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert prompt to tokens
    tokens = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)

    # Initialize mask logits to all ones: sigmoid(1) > 0.5, so the initial
    # hard mask keeps every token
    mask_logits = torch.ones_like(tokens).float()
    mask_logits = torch.nn.Parameter(mask_logits.requires_grad_().to(device))

    # Weight decay regularizes the mask logits towards zero, biasing the
    # procedure towards erasing as few tokens as possible
    optimizer = torch.optim.SGD(params=[mask_logits], lr=10.0, momentum=0.0, weight_decay=1e-2)

    for i in range(num_iters):
        optimizer.zero_grad()

        mask_sigmoid = torch.sigmoid(mask_logits)

        # Hard-threshold the soft mask; erased positions become token id 0 ([PAD])
        binary_mask = (mask_sigmoid >= 0.5).long()
        masked_tokens = binary_mask * tokens
        masked_prompt = tokenizer.decode(masked_tokens[0][1:-1])  # strip [CLS] and [SEP]

        # If the masked prompt is labelled harmful, return True
        model_output = model(torch.tensor(tokenizer.encode(masked_prompt)).unsqueeze(0).to(device))
        if model_output.logits[0].argmax().item() == 0:
            return True, masked_prompt

        # Soft masking in embedding space: interpolate each token embedding
        # with the [PAD] embedding so the loss is differentiable in the mask
        embeddings = word_embeddings(tokens)
        embeddings = mask_sigmoid.unsqueeze(2) * embeddings \
            + (1 - mask_sigmoid.unsqueeze(2)) * word_embeddings(0 * tokens)

        # Class 0 is harmful
        output = model(inputs_embeds=embeddings, labels=torch.tensor([0]).to(device))
        loss = output.loss
        loss.backward()
        optimizer.step()

    # Final check with the optimized mask
    mask_sigmoid = torch.sigmoid(mask_logits)
    binary_mask = (mask_sigmoid >= 0.5).long()
    masked_tokens = binary_mask * tokens
    masked_prompt = tokenizer.decode(masked_tokens[0][1:-1])

    model_output = model(torch.tensor(tokenizer.encode(masked_prompt)).unsqueeze(0).to(device))
    if model_output.logits[0].argmax().item() == 0:
        return True, masked_prompt
    else:
        return False, masked_prompt

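
# Example usage (illustrative sketch, mirroring the __main__ block below;
# assumes the fine-tuned DistilBERT safety classifier from
# models/distilbert_suffix.pt, with class 0 = harmful):
#
#   tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
#   model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
#   model.load_state_dict(torch.load('models/distilbert_suffix.pt'))
#   model.eval()
#   emb = model.distilbert.embeddings.word_embeddings
#   harmful, masked_prompt = grad_ec('Some prompt to check', model, tokenizer, emb)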

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Adversarial masks for the safety classifier.')
    parser.add_argument('--prompts_file', type=str, default='data/adversarial_prompts_t_20.txt', help='File containing prompts')
    parser.add_argument('--num_iters', type=int, default=50, help='Number of iterations')
    parser.add_argument('--model_wt_path', type=str, default='models/distilbert_suffix.pt', help='Path to model weights')
    parser.add_argument('--results_file', type=str, default='results/grad_ec_results.json', help='Path to results file')

    args = parser.parse_args()

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model and tokenizer
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')

    # Load model weights
    model_wt_path = args.model_wt_path
    model.load_state_dict(torch.load(model_wt_path, map_location=device))
    model.to(device)
    model.eval()

    prompts_file = args.prompts_file
    num_iters = args.num_iters
    results_file = args.results_file

    print('\n* * * * * Experiment Parameters * * * * *')
    print('Prompts file: ' + prompts_file)
    print('Number of iterations: ' + str(num_iters))
    print('Model weights: ' + model_wt_path)
    print('* * * * * * * * * * * * * * * * * * * * *\n')

    # Load prompts
    prompts = []
    with open(prompts_file, 'r') as f:
        for line in f:
            prompts.append(line.strip())

    print("Loaded " + str(len(prompts)) + " prompts.")

    list_of_bools = []
    start_time = time.time()

    # Open results file and load previous results JSON as a dictionary
    results_dict = {}
    # Create results file if it does not exist
    if not os.path.exists(results_file):
        with open(results_file, 'w') as f:
            json.dump(results_dict, f)
    with open(results_file, 'r') as f:
        results_dict = json.load(f)

    for num_done, input_prompt in enumerate(prompts):
        decision, masked_prompt = grad_ec(input_prompt, model, tokenizer,
                                          model.distilbert.embeddings.word_embeddings,
                                          num_iters=num_iters)
        list_of_bools.append(decision)

        # print("ORIGINAL PROMPT: " + input_prompt)
        # print("Masked Prompt: " + masked_prompt)
        # print("Prediction: " + ("harmful" if decision else "safe"))
        # input('Press enter to continue')

        percent_harmful = (sum(list_of_bools) / len(list_of_bools)) * 100.
        current_time = time.time()
        elapsed_time = current_time - start_time
        time_per_prompt = elapsed_time / (num_done + 1)

        print("    Checking safety... " + progress_bar((num_done + 1) / len(prompts))
              + f' Detected harmful = {percent_harmful:5.1f}%'
              + f' Time/prompt = {time_per_prompt:5.1f}s', end="\r")

    print("")

    # Save results
    results_dict[str(dict(num_iters=num_iters))] = dict(percent_harmful=percent_harmful,
                                                        time_per_prompt=time_per_prompt)

    with open(results_file, 'w') as f:
        json.dump(results_dict, f, indent=2)
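
# The results file then contains one entry per setting of num_iters, e.g.
# (illustrative numbers only):
#
#   {
#     "{'num_iters': 50}": {
#       "percent_harmful": 72.0,
#       "time_per_prompt": 1.3
#     }
#   }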