# faithfulness.py
import time, sys
import torch
print("CUDA is available:", torch.cuda.is_available())
from accelerate import Accelerator
import pandas as pd
import json
import copy  # needed for the deepcopy calls on prompts below
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoProcessor, LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, AutoConfig
from PIL import Image
import random, os
from tqdm import tqdm
from read_datasets import read_data
from generation_and_prompting import *
from mm_shap_cc_shap import *
from other_faith_tests import *
from config import *

torch.cuda.empty_cache()
accelerator = Accelerator()
accelerator.free_memory()

# silence transformers warnings, then quiet the shap logger via the stdlib logging module
from transformers.utils import logging as hf_logging
hf_logging.set_verbosity_error()
import logging
logging.getLogger('shap').setLevel(logging.ERROR)

random.seed(42)
t1 = time.time()
# command line arguments
c_task = sys.argv[1]            # task name, a key of MULT_CHOICE_DATA or OPEN_ENDED_DATA
model_name = sys.argv[2]        # model key into the MODELS dict from config
num_samples = int(sys.argv[3])  # number of samples to run
save_json = int(sys.argv[4])    # 1 to dump results to results_json/, 0 to skip
data_root = sys.argv[5]         # root directory of the datasets
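# Example invocation (the data path is a placeholder, not a path confirmed by this repo):
#   python faithfulness.py mscoco bakllava 100 1 /path/to/data/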
# load model
if "mplug" in model_name:
    config = AutoConfig.from_pretrained(MODELS[model_name], trust_remote_code=True)
    with torch.no_grad():
        model = AutoModel.from_pretrained(MODELS[model_name], attn_implementation='sdpa', torch_dtype=torch.half,
                                          trust_remote_code=True, device_map="auto").eval()
# elif model_name == "llava_vicuna": # comment this in if you want to use quantisation for llava_vicuna and flash_attention_2
# from transformers import BitsAndBytesConfig
# # specify how to quantize the model with bitsandbytes
# quantization_config = BitsAndBytesConfig(
# # load_in_8bit=True,
# load_in_4bit=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.float16,
# ) # 8 just load_in_8bit=True,
# with torch.no_grad():
# model = LlavaNextForConditionalGeneration.from_pretrained(MODELS[model_name], torch_dtype=torch.float16,
# low_cpu_mem_usage=True,
# use_flash_attention_2=True,
# quantization_config = quantization_config
# ) # .to("cuda") not needed for bitsandbytes anymore
else:
    if model_name == "bakllava":
        ModelClass = LlavaForConditionalGeneration
    else:
        ModelClass = LlavaNextForConditionalGeneration
    with torch.no_grad():
        model = ModelClass.from_pretrained(MODELS[model_name], torch_dtype=torch.float16,
                                           low_cpu_mem_usage=True,  # device_map="auto"
                                           # use_flash_attention_2=True,  # comment this in if you want to use flash_attention_2
                                           ).to("cuda")
# load tokenizer
if "mplug" in model_name:
    tokenizer_real = AutoTokenizer.from_pretrained(MODELS[model_name])
    processor = model.init_processor(tokenizer_real)
    tokenizer = {"tokenizer": tokenizer_real, "processor": processor}
else:
    tokenizer = AutoProcessor.from_pretrained(MODELS[model_name])
print(f"Done loading model and tokenizer after {time.time()-t1:.2f}s.")
# helper language model used by the atanasova_counterfactual, turpin, and lanham tests
if 'atanasova_counterfactual' in TESTS or 'turpin' in TESTS or 'lanham' in TESTS:
    with torch.no_grad():
        helper_model = AutoModelForCausalLM.from_pretrained(MODELS['llama2-13b-chat'], torch_dtype=torch.float16, device_map="auto", token=True)
    helper_tokenizer = AutoTokenizer.from_pretrained(MODELS['llama2-13b-chat'], use_fast=False, padding_side='left')
    print(f"Loaded helper model after {time.time()-t1:.2f}s.")
else:
    print("No need for a helper model given the subselection of tests.")
# print(lm_generate('I enjoy walking with my cute dog.', helper_model, helper_tokenizer, max_new_tokens=max_new_tokens))
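# High-level flow of the experiment loop below: prepare (sample, answer, image) triples
# for the chosen task, then, per sample, query the model post-hoc and with CoT,
# run the selected faithfulness tests, aggregate the scores, and periodically dump
# everything to results_json/.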
if __name__ == '__main__':
    ############################# run experiments on data
    res_dict = {}
    formatted_samples, correct_answers, wrong_answers, image_paths = [], [], [], []
    accuracy, accuracy_cot = 0, 0
    atanasova_counterfact_count, turpin_test_count, count, cc_shap_post_hoc_sum, cc_shap_cot_sum = 0, 0, 0, 0, 0
    mm_shap_post_hoc_sum, mm_shap_expl_post_hoc_sum, mm_shap_cot_sum, mm_shap_expl_cot_sum = 0, 0, 0, 0
    lanham_early_count, lanham_mistake_count, lanham_paraphrase_count, lanham_filler_count = 0, 0, 0, 0

    print("Preparing data...")
    if c_task in MULT_CHOICE_DATA.keys():  ###### VALSE tests
        # read the VALSE data from the json files
        images_path = f"{data_root}{MULT_CHOICE_DATA[c_task][0]}"
        foils_path = f"{data_root}{MULT_CHOICE_DATA[c_task][1]}"
        foils_data = read_data(c_task, foils_path, images_path, data_root)
        for foil_id, foil in tqdm(foils_data.items()):
            if count + 1 > num_samples:
                break
            if c_task == 'mscoco':
                # for everything other than VALSE: pretend the sample was accepted by annotators
                caption_fits = 3
            else:  # the subtask stems from VALSE data
                caption_fits = foil['mturk']['caption']  # take only samples accepted by annotators
            if caption_fits >= 2:  # MTurk filtering! Use only the validated set
                test_img_path = os.path.join(images_path, foil["image_file"])
                if c_task == 'mscoco':
                    confounder = random.sample(sorted(foils_data.items()), 1)[0][1]
                    test_sentences = [foil["caption"], confounder["caption"]]
                else:
                    if c_task == 'plurals':
                        test_sentences = [foil["caption"][0], foil["foils"][0]]
                    else:
                        test_sentences = [foil["caption"], foil["foils"][0]]
                # shuffle the order of caption and foil so the correct answer is not always A
                if random.choice([0, 1]) == 0:
                    formatted_sample = format_example_valse_pairwise(test_sentences[0], test_sentences[1])
                    correct_answer, wrong_answer = 'A', 'B'
                else:
                    formatted_sample = format_example_valse_pairwise(test_sentences[1], test_sentences[0])
                    correct_answer, wrong_answer = 'B', 'A'
                formatted_samples.append(formatted_sample)
                correct_answers.append(correct_answer)
                wrong_answers.append(wrong_answer)
                image_paths.append(test_img_path)
                count += 1
    elif c_task in OPEN_ENDED_DATA.keys():  # open-ended generation tasks
        images_path = f"{data_root}{OPEN_ENDED_DATA[c_task][0]}"
        qa_path = f"{data_root}{OPEN_ENDED_DATA[c_task][1]}"
        vqa_data = read_data(c_task, qa_path, images_path, data_root)
        for foil_id, foil in tqdm(vqa_data.items()):
            if count + 1 > num_samples:
                break
            test_img_path = os.path.join(images_path, foil["image_file"])
            question = foil["caption"]
            formatted_sample = format_example_vqa_gqa(question)  # takes in the question
            if c_task == 'vqa':
                correct_answer = foil["answers"]  # VQA has multiple answer annotations per question
            else:
                correct_answer = foil["answer"]
            wrong_answer = "impossible to give"
            formatted_samples.append(formatted_sample)
            correct_answers.append(correct_answer)
            wrong_answers.append(wrong_answer)
            image_paths.append(test_img_path)
            count += 1
    else:
        raise NotImplementedError(f'Your specified task has no implementation: {c_task}')
print("Done preparing data. Running test...")
for k, formatted_sample, correct_answer, wrong_answer, image_path in tqdm(zip(range(len(formatted_samples)), formatted_samples, correct_answers, wrong_answers, image_paths)):
raw_image = Image.open(image_path).convert("RGB") # read image
if c_task in MULT_CHOICE_DATA.keys():
labels = LABELS['binary']
elif c_task in OPEN_ENDED_DATA.keys():
labels = None
else:
labels = LABELS[c_task]
        t7 = time.time()
        if "mplug" in model_name:
            # compute model accuracy post-hoc
            inp_ask_for_prediction = prompt_answer_with_input(formatted_sample, c_task)
            prediction = vlm_predict(copy.deepcopy(inp_ask_for_prediction), raw_image, model, tokenizer, c_task, labels=labels)
            # post-hoc explanation
            input_pred_ask_for_expl = prompt_post_hoc_expl_with_input(formatted_sample, prediction, c_task)
            input_pred_expl = vlm_generate(copy.deepcopy(input_pred_ask_for_expl), raw_image, model, tokenizer, max_new_tokens=max_new_tokens, repeat_input=True)
            # for accuracy with CoT: first let the model generate the CoT, then the answer
            input_ask_for_cot = prompt_cot_with_input(formatted_sample, c_task)
            output_cot = vlm_generate(copy.deepcopy(input_ask_for_cot), raw_image, model, tokenizer, max_new_tokens=max_new_tokens, repeat_input=False, skip_special_tokens=True)
            input_cot_ask_for_pred = prompt_answer_after_cot_with_input(output_cot, c_task, inputt=formatted_sample)
            prediction_cot = vlm_predict(copy.deepcopy(input_cot_ask_for_pred), raw_image, model, tokenizer, c_task, labels=labels)
        else:
            # compute model accuracy post-hoc
            inp_ask_for_prediction = prompt_answer_with_input(formatted_sample, c_task)
            prediction = vlm_predict(inp_ask_for_prediction, raw_image, model, tokenizer, c_task, labels=labels)
            # post-hoc explanation
            input_pred_ask_for_expl = prompt_post_hoc_expl_with_input(formatted_sample, prediction, c_task)
            input_pred_expl = vlm_generate(input_pred_ask_for_expl, raw_image, model, tokenizer, max_new_tokens=max_new_tokens, repeat_input=True)
            # for accuracy with CoT: first let the model generate the CoT, then the answer
            input_ask_for_cot = prompt_cot_with_input(formatted_sample, c_task)
            input_cot = vlm_generate(input_ask_for_cot, raw_image, model, tokenizer, max_new_tokens=max_new_tokens, repeat_input=True)
            input_cot_ask_for_pred = prompt_answer_after_cot_with_input(input_cot, c_task)
            prediction_cot = vlm_predict(input_cot_ask_for_pred, raw_image, model, tokenizer, c_task, labels=labels)
        accuracy_sample = evaluate_prediction(prediction, correct_answer, c_task)
        accuracy += accuracy_sample
        accuracy_cot_sample = evaluate_prediction(prediction_cot, correct_answer, c_task)
        accuracy_cot += accuracy_cot_sample
        # post-hoc tests
        if 'atanasova_counterfactual' in TESTS:
            atanasova_counterfact, atanasova_counterfact_info = faithfulness_test_atanasova_etal_counterfact(formatted_sample, raw_image, prediction, model, tokenizer, c_task, helper_model, helper_tokenizer, labels)
        else:
            atanasova_counterfact, atanasova_counterfact_info = 0, 0
        if 'cc_shap-posthoc' in TESTS:
            # the SHAP values for the prediction are computed once here and reused by the CoT measure below
            mm_score_post_hoc, mm_score_expl_post_hoc, score_post_hoc, dist_correl_ph, mse_ph, var_ph, kl_div_ph, js_div_ph, shap_plot_info_ph, tuple_shap_values_prediction = cc_shap_measure(copy.deepcopy(inp_ask_for_prediction), prediction, input_pred_ask_for_expl, raw_image, model, tokenizer, c_task, tuple_shap_values_prediction=None, expl_type='post_hoc', max_new_tokens=max_new_tokens)
        else:
            mm_score_post_hoc, mm_score_expl_post_hoc, score_post_hoc, dist_correl_ph, mse_ph, var_ph, kl_div_ph, js_div_ph, shap_plot_info_ph = 0, 0, 0, 0, 0, 0, 0, 0, 0
            tuple_shap_values_prediction = None  # bug fix: otherwise undefined when only cc_shap-cot runs
        # CoT tests
        if 'turpin' in TESTS:
            turpin, turpin_info = faithfulness_test_turpin_etal(formatted_sample, input_cot_ask_for_pred, prediction_cot, raw_image, prediction_cot, correct_answer, wrong_answer, model, tokenizer, c_task, helper_model, helper_tokenizer, labels, max_new_tokens=max_new_tokens)
        else:
            turpin, turpin_info = 0, 0
        if 'lanham' in TESTS:
            lanham_early, lanham_mistake, lanham_paraphrase, lanham_filler, lanham_early_info = faithfulness_test_lanham_etal(prediction_cot, input_cot, input_ask_for_cot, raw_image, model, tokenizer, c_task, helper_model, helper_tokenizer, labels, max_new_tokens=max_new_tokens)
        else:
            lanham_early, lanham_mistake, lanham_paraphrase, lanham_filler, lanham_early_info = 0, 0, 0, 0, 0
        if 'cc_shap-cot' in TESTS:
            mm_score_cot, mm_score_expl_cot, score_cot, dist_correl_cot, mse_cot, var_cot, kl_div_cot, js_div_cot, shap_plot_info_cot, _ = cc_shap_measure(copy.deepcopy(inp_ask_for_prediction), prediction, copy.deepcopy(input_ask_for_cot), raw_image, model, tokenizer, c_task, tuple_shap_values_prediction, expl_type='cot', max_new_tokens=max_new_tokens)
        else:
            mm_score_cot, mm_score_expl_cot, score_cot, dist_correl_cot, mse_cot, var_cot, kl_div_cot, js_div_cot, shap_plot_info_cot = 0, 0, 0, 0, 0, 0, 0, 0, 0
        # aggregate results
        atanasova_counterfact_count += atanasova_counterfact
        cc_shap_post_hoc_sum += score_post_hoc
        turpin_test_count += turpin
        lanham_early_count += lanham_early
        lanham_mistake_count += lanham_mistake
        lanham_paraphrase_count += lanham_paraphrase
        lanham_filler_count += lanham_filler
        cc_shap_cot_sum += score_cot
        mm_shap_post_hoc_sum += mm_score_post_hoc
        mm_shap_expl_post_hoc_sum += mm_score_expl_post_hoc
        mm_shap_cot_sum += mm_score_cot
        mm_shap_expl_cot_sum += mm_score_expl_cot
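        # One record per sample, keyed by task, model, and sample index; the "post-hoc"
        # and "cot" sub-dicts mirror each other so the two explanation settings can be
        # compared entry by entry downstream.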
res_dict[f"{c_task}_{model_name}_{k}"] = {
"image_path": image_path,
"sample": formatted_sample,
"correct_answer": correct_answer,
"post-hoc": {
# "inp_pred_expl": input_pred_expl, # input, prediction, expl
"prediction": prediction,
"accuracy": accuracy_sample,
"shap_plot_info_post_hoc": shap_plot_info_ph,
"cc_shap-posthoc": f"{score_post_hoc:.2f}",
"t-shap_post_hoc": f"{mm_score_post_hoc*100:.0f}",
"t-shap_expl_post_hoc": f"{mm_score_expl_post_hoc*100:.0f}",
"atanasova_counterfact": atanasova_counterfact,
"atanasova_counterfact_info": atanasova_counterfact_info,
"other_measures_post_hoc": {
"dist_correl": f"{dist_correl_ph:.2f}",
"mse": f"{mse_ph:.2f}",
"var": f"{var_ph:.2f}",
"kl_div": f"{kl_div_ph:.2f}",
"js_div": f"{js_div_ph:.2f}"
},
},
"cot": {
"inp_cot_askpred": input_cot_ask_for_pred, # input, generated cot, prompt for final answer
"pred_cot": prediction_cot, # prediction after cot
"accuracy_cot": accuracy_cot_sample,
"shap_plot_info_cot": shap_plot_info_cot,
# add plot info for the rest of the tests as well
"turpin": turpin,
"turpin_info": turpin_info,
"lanham_early": lanham_early,
"lanham_early_info": lanham_early_info,
"lanham_mistake": lanham_mistake,
"lanham_paraphrase": lanham_paraphrase,
"lanham_filler": lanham_filler,
"cc_shap-cot": f"{score_cot:.2f}",
"t-shap_cot": f"{mm_score_cot*100:.0f}",
"t-shap_expl_cot": f"{mm_score_expl_cot*100:.0f}",
"other_measures_cot": {
"dist_correl": f"{dist_correl_cot:.2f}",
"mse": f"{mse_cot:.2f}",
"var": f"{var_cot:.2f}",
"kl_div": f"{kl_div_cot:.2f}",
"js_div": f"{js_div_cot:.2f}"
},
}
}
        # print running results after every sample; dump them to json every 10 samples
        if (k+1) % 1 == 0:
            print(f"Ran {TESTS} on {c_task} {k+1} samples with model {model_name}. Reporting accuracy and faithfulness percentage.\n")
            print(f"Accuracy % : {accuracy*100/(k+1):.2f} ")
            print(f"Atanasova Counterfact % : {atanasova_counterfact_count*100/(k+1):.2f} ")
            print(f"CC-SHAP post-hoc mean score : {cc_shap_post_hoc_sum/(k+1):.2f} ")
            print(f"Accuracy CoT % : {accuracy_cot*100/(k+1):.2f} ")
            print(f"Turpin % : {turpin_test_count*100/(k+1):.2f} ")
            print(f"Lanham Early Answering % : {lanham_early_count*100/(k+1):.2f} ")
            print(f"Lanham Filler % : {lanham_filler_count*100/(k+1):.2f} ")
            print(f"Lanham Mistake % : {lanham_mistake_count*100/(k+1):.2f} ")
            print(f"Lanham Paraphrase % : {lanham_paraphrase_count*100/(k+1):.2f} ")
            print(f"CC-SHAP CoT mean score : {cc_shap_cot_sum/(k+1):.2f} ")
            print(f"T-SHAP post-hoc mean score % : {mm_shap_post_hoc_sum/(k+1)*100:.2f} ")
            print(f"T-SHAP expl post-hoc mean score %: {mm_shap_expl_post_hoc_sum/(k+1)*100:.2f} ")
            print(f"T-SHAP CoT mean score % : {mm_shap_cot_sum/(k+1)*100:.2f} ")
            print(f"T-SHAP expl CoT mean score % : {mm_shap_expl_cot_sum/(k+1)*100:.2f} ")
            c = time.time()-t7
            print(f"A step ran for {c // 60 % 60:.0f} minutes, {c % 60:.2f} seconds.")
        if save_json and (k+1) % 10 == 0:
            # save intermediate results to a json file; create the results_json directory if it does not exist
            if not os.path.exists('results_json'):
                os.makedirs('results_json')
            with open(f"results_json/{c_task}_{model_name}_{k+1}.json", 'w') as file:
                json.dump(res_dict, file)
    if save_json:
        # save final results to a json file; create the results_json directory if it does not exist
        if not os.path.exists('results_json'):
            os.makedirs('results_json')
        with open(f"results_json/{c_task}_{model_name}_{count}_final.json", 'w') as file:
            json.dump(res_dict, file)
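    # To inspect saved results later (a sketch; the file name below is a placeholder
    # for whatever this run actually wrote):
    #   with open('results_json/mscoco_bakllava_100_final.json') as f:
    #       res = json.load(f)
    #   first = next(iter(res.values()))
    #   print(first['post-hoc']['prediction'], first['cot']['pred_cot'])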
print(f"Ran {TESTS} on {c_task} {count} samples with model {model_name}. Reporting accuracy and faithfulness percentage.\n")
print(f"Accuracy % : {accuracy*100/count:.2f} ")
print(f"Atanasova Counterfact % : {atanasova_counterfact_count*100/count:.2f} ")
print(f"CC-SHAP post-hoc mean score : {cc_shap_post_hoc_sum/count:.2f} ")
print(f"Accuracy CoT % : {accuracy_cot*100/count:.2f} ")
print(f"Turpin % : {turpin_test_count*100/count:.2f} ")
print(f"Lanham Early Answering % : {lanham_early_count*100/count:.2f} ")
print(f"Lanham Filler % : {lanham_filler_count*100/count:.2f} ")
print(f"Lanham Mistake % : {lanham_mistake_count*100/count:.2f} ")
print(f"Lanham Paraphrase % : {lanham_paraphrase_count*100/count:.2f} ")
print(f"CC-SHAP CoT mean score : {cc_shap_cot_sum/count:.2f} ")
print(f"T-SHAP post-hoc mean score % : {mm_shap_post_hoc_sum/count*100:.2f} ")
print(f"T-SHAP expl post-hoc mean score %: {mm_shap_expl_post_hoc_sum/count*100:.2f} ")
print(f"T-SHAP CoT mean score % : {mm_shap_cot_sum/count*100:.2f} ")
print(f"T-SHAP expl CoT mean score % : {mm_shap_expl_cot_sum/count*100:.2f} ")
c = time.time()-t1
print(f"\nThis script ran for {c // 86400:.2f} days, {c // 3600 % 24:.2f} hours, {c // 60 % 60:.2f} minutes, {c % 60:.2f} seconds.")